diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 3592da4c4636..1481859e94a9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -11,6 +11,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -23,6 +24,7 @@ namespace mlir {
 namespace xegpu {
 class TensorDescType;
+class DistributeLayoutAttr;
 class LayoutAttr;
 class SliceAttr;
 } // namespace xegpu
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index a94987885c9e..b4d696444cc4 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,22 +175,36 @@ def XeGPU_FenceScopeAttr:
   let assemblyFormat = "$value";
 }

-def LayoutTrait: AttrInterface<"LayoutTrait"> {
+def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
   let cppNamespace = "::mlir::xegpu";
   let description = [{
     Common trait for all XeGPU layouts.
   }];
   let methods = [
+    InterfaceMethod<"Check the availability of workgroup level layouts",
+                    "bool",
+                    "isForWorkgroup">,
     InterfaceMethod<"Get the rank of attribute",
                     "int64_t",
                     "getRank">,
+    InterfaceMethod<"Get the num of effective subgroups",
+                    "int64_t",
+                    "getNumSubgroups", (ins), [{
+                      std::optional<SmallVector<int64_t>> sgLayout =
+                          llvm::cast<xegpu::DistributeLayoutAttr>(tablegen_opaque_val)
+                              .getSgLayoutAsInt();
+                      if (sgLayout.has_value())
+                        return computeProduct(*sgLayout);
+                      return 0;
+                    }], [{}]>,
     InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
                     "std::optional<SmallVector<int64_t>>",
                     "getSgLayoutAsInt">,
     InterfaceMethod<"Get the SgData field of the attribute as integer array",
                     "std::optional<SmallVector<int64_t>>",
                     "getSgDataAsInt">,
+    InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
+                    "xegpu::DistributeLayoutAttr",
+                    "dropSgLayoutAndData">,
     InterfaceMethod<[{Delinearizes a linear subgroup ID into its
                       multidimensional indices based on the effective
                       subgroup layout.}],
                     "FailureOr<SmallVector<Value>>",
@@ -206,7 +220,7 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
   ];
 }

-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
   let summary = [{
     Describes the data distribution to subgroups and work-items for a tensor
     specified by the tensor descriptor.
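The point of the renamed interface is that transformation code can now treat `LayoutAttr` and `SliceAttr` uniformly. A minimal sketch of a consumer, using only the interface methods declared above (the helper function itself is hypothetical, not part of this patch):

```cpp
// Sketch: a consumer of DistributeLayoutAttr; `layout` may be a LayoutAttr
// or a SliceAttr, and the caller no longer needs to know which.
static bool isDistributableAcrossSubgroups(xegpu::DistributeLayoutAttr layout) {
  if (!layout || !layout.isForWorkgroup())
    return false; // only workgroup-level layouts carry sg_layout/sg_data
  // getNumSubgroups() multiplies the sg_layout dims, e.g. [2, 4] -> 8.
  return layout.getNumSubgroups() > 0;
}
```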
@@ -328,12 +342,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
   ];

   let extraClassDeclaration = [{
-    bool isWgLayout() {
+    bool isForWorkgroup() {
       return getSgLayout() != nullptr;
     }

-    bool isSgLayout() {
-      return !isWgLayout();
+    bool isForSubgroup() {
+      return !isForWorkgroup();
     }

     int64_t getRank() {
@@ -393,7 +407,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
 }

-def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
   let summary = [{Describes the data distribution and sharing among subgroups
                   or work-items.}];

   let description = [{
@@ -420,7 +434,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
   }];

   let parameters = (ins
-    "xegpu::LayoutTrait": $parent,
+    "xegpu::DistributeLayoutAttr": $parent,
     "DenseI64ArrayAttr": $dims
   );
@@ -438,16 +452,16 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
       return parent.getOrder();
     }

-    bool isWgLayout() const {
+    bool isForWorkgroup() const {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-      return parent.isWgLayout();
+      return parent.isForWorkgroup();
     }

-    bool isSgLayout() const {
+    bool isForSubgroup() const {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-      return parent.isSgLayout();
+      return parent.isForSubgroup();
     }

     /// Returns the SgLayout of the attribute, computed by applying
@@ -474,6 +488,20 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
       return std::nullopt;
     }

+    SliceAttr dropSgLayoutAndData() {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      parent = parent.dropSgLayoutAndData();
+      return SliceAttr::get(getContext(), parent, attr.getDims());
+    }
+
+    SliceAttr dropInstData() {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      parent = parent.dropInstData();
+      return SliceAttr::get(getContext(), parent, attr.getDims());
+    }
+
     /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
     /// #xegpu.slice<#xegpu.slice<#xegpu.layout, dims = [0]>, dims = [0]>
     /// it will coalesce two slice operations and return a simplified SliceAttr
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index eb54d6887681..ab471a1f33ef 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -232,6 +232,14 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       return static_cast<int>(MemorySpace::Global);
     }

+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDataShape() {
+      return getTensorDescShape();
+    }
+
   }];
 }

@@ -262,6 +270,23 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
     xegpu::TensorDescType getTensorDescType() {
       return getTensorDesc().getType();
     }
+
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
+    }
+
+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDataShape() {
+      return getTensorDescType().getShape();
+    }
+
   }];

   let assemblyFormat = [{
@@ -343,6 +368,24 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     xegpu::TensorDescType getTensorDescType() {
      return getTensorDesc().getType();
     }
+
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
+    }
+
+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDataShape() {
+      return getTensorDescType().getShape();
+    }
+
   }];

   let assemblyFormat = [{
@@ -417,6 +460,23 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     xegpu::TensorDescType getTensorDescType() {
       return getTensorDesc().getType();
     }
+
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
+    }
+
+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDataShape() {
+      return getTensorDescType().getShape();
+    }
+
   }];

   let assemblyFormat = [{
@@ -640,6 +700,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     xegpu::TensorDescType getTensorDescType() {
       return dyn_cast<xegpu::TensorDescType>(getSourceType());
     }
+
   }];

   let assemblyFormat = [{
@@ -1150,7 +1211,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
                        Variadic<Index>: $offsets,
                        DenseI64ArrayAttr: $const_offsets,
-                       OptionalAttr<LayoutTrait>:$layout
+                       OptionalAttr<DistributeLayoutAttr>:$layout
   );
   let results = (outs XeGPU_ValueType:$res);
   let assemblyFormat = [{
@@ -1175,12 +1236,16 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,

   let builders = [
     OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
-               "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+               "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr": $layout)>,
   ];

   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
       return getMixedValues(getConstOffsets(), getOffsets(), getContext());
     }
+
+    ArrayRef<int64_t> getDataShape() {
+      return getRes().getType().getShape();
+    }
   }];

   let hasVerifier = 1;
@@ -1194,7 +1259,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<LayoutTrait>:$layout
+    OptionalAttr<DistributeLayoutAttr>:$layout
   );
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
                           prop-dict attr-dict `` `:` type(operands)}];
@@ -1213,12 +1278,17 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   }];
   let builders = [
     OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
-               "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+               "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
       return getMixedValues(getConstOffsets(), getOffsets(), getContext());
     }
+
+    ArrayRef<int64_t> getDataShape() {
+      return getData().getType().getShape();
+    }
+
   }];

   let hasVerifier = 1;
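The repeated `getMixedOffsets` bodies above fold the op's static `const_offsets` and dynamic `offsets` operands into a single `OpFoldResult` list via the upstream `getMixedValues` utility. A sketch of the same merging done standalone (`y` and `builder` are assumed to be in scope; the values are hypothetical):

```cpp
// A static entry of ShapedType::kDynamic is replaced positionally by the
// next dynamic value, so const_offsets = [0, kDynamic] plus the dynamic
// operand %y yields [IntegerAttr(0), Value(%y)].
SmallVector<OpFoldResult> mixed =
    getMixedValues(/*staticValues=*/{0, ShapedType::kDynamic},
                   /*dynamicValues=*/ValueRange{y}, builder.getContext());
```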
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index db8608c6d20b..b2b2d3ab8523 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_

 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 namespace mlir {
 class VectorType;
@@ -128,6 +129,20 @@ void doSCFStructuralTypeConversionWithTensorType(Operation *op,
 /// if no GPU module parent or XeVM target attribute exists.
 std::optional<std::string> getChipStr(Operation *op);

+/// Generates element-wise addition ops of two arrays with automatic
+/// alignment. When the input arrays have different sizes, the shorter array
+/// is right-aligned with the longer array, and the unmatched leading
+/// elements from the longer array are preserved unchanged. This is commonly
+/// used for offset computation where higher-dimensional offsets need to be
+/// added to lower-dimensional adjustments.
+///
+/// Example:
+///   lhs = [l1, l2, l3], rhs = [r1, r2]
+///   Result: [l1, l2+r1, l3+r2]
+SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
+                                              ArrayRef<OpFoldResult> lhs,
+                                              ArrayRef<OpFoldResult> rhs);
+
 } // namespace xegpu
 } // namespace mlir
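A short usage sketch of the declaration above, mirroring the documented example; `c0`, `c8`, `c16`, `sgOffY`, `sgOffX`, `builder`, and `loc` are assumed to already be in scope:

```cpp
// Offsets of ranks 3 and 2 are right-aligned before the element-wise add.
SmallVector<OpFoldResult> orig = {c0, c8, c16};  // [l1, l2, l3]
SmallVector<OpFoldResult> sg = {sgOffY, sgOffX}; // [r1, r2]
// Yields {c0, c8 + sgOffY, c16 + sgOffX}; the unmatched leading c0 is kept.
SmallVector<OpFoldResult> res =
    xegpu::addWithRightAligned(builder, loc, orig, sg);
```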
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 8ea8cb1f4597..a2d708be0e93 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -271,7 +271,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
                                   Value linearId) {
   // delinearizeSubgroupId is only available for
   // workgroup-level layout attribute
-  if (!isWgLayout())
+  if (!isForWorkgroup())
     return failure();

   // TODO: handle order attribute
@@ -290,12 +290,13 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }

-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by LayoutAttr.
+/// Implements DistributeLayoutAttr::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                        ArrayRef<int64_t> shape) {
-  if (!isWgLayout())
+  if (!isForWorkgroup())
     return failure();

   SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
@@ -322,7 +323,7 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
 //===----------------------------------------------------------------------===//
 LogicalResult
 SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
-                  xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
+                  xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
   if (!parent || !dims)
     return emitError() << "expected parent layout and dims attribute";

@@ -340,7 +341,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
 }

 SliceAttr SliceAttr::flatten() const {
-  xegpu::LayoutTrait parent = getParent();
+  xegpu::DistributeLayoutAttr parent = getParent();
   SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});

   while (auto sliceAttr = dyn_cast<SliceAttr>(parent)) {
@@ -375,13 +376,14 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return parent.delinearizeSubgroupId(builder, loc, linearId);
 }

-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by SliceAttr.
+/// Implements DistributeLayoutAttr::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// SliceAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                       ArrayRef<int64_t> shape) {
   assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
-  if (!isWgLayout())
+  if (!isForWorkgroup())
     return failure();

   SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 906c71d8b8da..c8d180b973f0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -938,8 +938,8 @@ LogicalResult ConvertLayoutOp::verify() {

   // both input and target layouts should be WgLayout or SgLayout at the same
   // time.
-  if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) &&
-      (!srcLayout.isSgLayout() || !resLayout.isSgLayout()))
+  if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
+      (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
     return emitOpError("expected input layout and target layout be WgLayout or "
                        "SgLayout at the same time.");

@@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                          TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
-                         LayoutTrait layout) {
+                         DistributeLayoutAttr layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() {
 void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
                           TypedValue<MemDescType> memDesc,
                           llvm::ArrayRef<OpFoldResult> offsets,
-                          LayoutTrait layout) {
+                          DistributeLayoutAttr layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index d82c541f3135..b3144e4c1e55 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -141,7 +141,7 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
     value = (Value)operandOrResult;

   xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
-  if (layout && layout.isSgLayout()) {
+  if (layout && layout.isForSubgroup()) {
     if (auto inst_data = layout.getInstData())
       return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());

@@ -205,12 +205,12 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
   bool hasWgLayoutOperands =
       llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
         xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
-        return layout && layout.isWgLayout();
+        return layout && layout.isForWorkgroup();
       });
   bool hasWgLayoutResults =
       llvm::any_of(op->getOpResults(), [](OpResult result) {
         xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
-        return layout && layout.isWgLayout();
+        return layout && layout.isForWorkgroup();
       });
   if (hasWgLayoutOperands || hasWgLayoutResults) {
     LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
@@ -272,7 +272,7 @@ void XeGPUBlockingPass::runOnOperation() {
     auto layout =
         llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
-    if (layout && layout.isWgLayout())
+    if (layout && layout.isForWorkgroup())
       return failure();

     int count;
@@ -289,7 +289,7 @@ void XeGPUBlockingPass::runOnOperation() {
       ArrayRef<int64_t> shape = type.getShape();
       xegpu::LayoutAttr layout = type.getLayoutAttr();
-      if (layout && layout.isWgLayout())
+      if (layout && layout.isForWorkgroup())
         return failure();

       int count;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8f1208e77ca5..93b4efcd125e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -34,38 +34,29 @@ using namespace mlir;

 namespace {

-// Check if there is sg id range attached to the scf.if op.
-static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
-                                 int64_t &endOfRange) {
-  Operation *parent = op->getParentOp();
-  // Find the outermost scf::IfOp with xegpu.sg_id_range.
+// Retrieve the RangeAttr if it is specified.
+static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
+  Operation *parent = op->getParentOfType<scf::IfOp>();
   while (parent) {
-    if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
-      if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
-              ifOp->getAttr("sg_id_range"))) {
-        startOfRange = attr.getStart().getInt();
-        endOfRange = attr.getEnd().getInt();
-        break;
-      }
-    }
-    parent = parent->getParentOp();
+    if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
+            parent->getAttr("sg_id_range")))
+      return attr;
+    parent = parent->getParentOfType<scf::IfOp>();
   }
-  // Return false if startOfRange is 0
-  return (startOfRange > 0 && endOfRange > startOfRange);
+  return {};
 }

 static std::pair<SmallVector<int64_t>, int>
-getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+getSgShapeAndCount(ArrayRef<int64_t> shape,
+                   xegpu::DistributeLayoutAttr layout) {
   int count = 1;
   SmallVector<int64_t> sgShape(shape);
-
-  if (layout && layout.isWgLayout()) {
-    DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
-    auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
-      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    else
-      sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+  if (layout && layout.isForWorkgroup()) {
+    SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt().value();
+    if (auto maybeSgData = layout.getSgDataAsInt())
+      sgShape = *maybeSgData;
+    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
+      sgShape = *maybeDerivedSgData;
     SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
     // Clamp distUnit to the original shape to handle cases where data is
     // shared among subgroups, which may cause distUnit to exceed the original
@@ -77,6 +68,67 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   return std::make_pair(sgShape, count);
 }

+/// Utility helper for deriving a list of offsets for each sub-TensorDesc
+/// or sub-MemDesc to be accessed by the current subgroup (sgId), based on
+/// the associated distribute layout attribute, the shape, the subgroup id,
+/// and the original offsets of the op.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+        xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+static LogicalResult
+genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
+               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
+  Location loc = op.getLoc();
+  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
+  // not applicable to ops without offsets operands.
+  if (origOffsets.empty())
+    return failure();
+
+  // not applicable to ops without workgroup layout attributes
+  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+  if (!layout || !layout.isForWorkgroup())
+    return failure();
+
+  Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+  // verify and adjust the sgId if the range specifier is present
+  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
+  if (sgIdRange) {
+    int64_t startOfRange = sgIdRange.getStart().getInt();
+    int64_t endOfRange = sgIdRange.getEnd().getInt();
+    // verify the RangeAttr against the layout attribute
+    if (layout.getNumSubgroups() != endOfRange - startOfRange)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    // adjust the sgId if necessary
+    if (startOfRange > 0) {
+      Value startOfRangeVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+      sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
+    }
+  }
+
+  // Compute the list of subgroup-relative offsets for sub-tensors or
+  // sub-memory descriptors to be accessed, based on the layout information.
+  ArrayRef<int64_t> wgShape = op.getDataShape();
+  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+  if (failed(maybeDescOffsets))
+    return failure();
+
+  // Compute the final global offsets for each accessed sub-tensor
+  // or sub-memory descriptor.
+  for (const auto &sgOffsets : *maybeDescOffsets) {
+    SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
+        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
+    offsetsList.push_back(std::move(newOffsets));
+  }
+
+  return success();
+}
+
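To make the offset pipeline concrete, here is the arithmetic the helper above materializes as IR, using the numbers from the `load_matrix` test later in this patch (the scalar recomputation is illustrative only):

```cpp
// wgShape = [64, 128], sg_layout = [2, 4], sg_data = [32, 32],
// original offsets = [0, 0].
int64_t sgId = 5;                       // linear subgroup id
int64_t idY = sgId / 4, idX = sgId % 4; // delinearize by sg_layout -> [1, 1]
int64_t offY = (idY * 32 + 0) % 64;     // index.mul / arith.addi / index.remu
int64_t offX = (idX * 32 + 0) % 128;    // -> final offsets [32, 32]
```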
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -128,79 +180,30 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   LogicalResult
   matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    // Ensure that the op has explicit offsets specified (either dynamic or
-    // constant).
-    if (op.getMixedOffsets().empty())
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();

-    Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
-    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-    if (!layout)
-      return failure();
-    Type elemTy = tdescTy.getElementType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
-    // sgLayout must be present for workgroup-level distribution.
-    SmallVector<int64_t> sgLayout;
-    if (auto sgLayoutAttr = layout.getSgLayout())
-      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    else
-      return rewriter.notifyMatchFailure(
-          op, "sgLayout attribute is required in layout");
-
-    // Get the subgroup ID
-    Value linearSgId =
-        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-    int64_t startOfRange = -1, endOfRange = -1;
-    bool sgIdRangeSpecified =
-        isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
-    if (sgIdRangeSpecified) {
-      int64_t sgCount = endOfRange - startOfRange;
-      if (computeProduct(sgLayout) != sgCount)
-        return rewriter.notifyMatchFailure(
-            op, "sg_layout size must match the sg_id_range");
-      // Subtract startOfRange from the original subgroup id to get
-      // the adjusted sg id
-      Value startOfRangeVal =
-          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
-      linearSgId =
-          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
-    }
-
-    auto maybeTdescOffsets =
-        layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-    if (failed(maybeTdescOffsets))
-      return failure();
-
+    Type elemTy = tdescTy.getElementType();
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-    xegpu::TensorDescType newTdescTy =
+    auto newTdescTy =
         xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                    layout.dropSgLayoutAndData());
-
-    SmallVector<Value> newCreateNdOps;
-    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
-
-    for (auto tdescOffsets : *maybeTdescOffsets) {
-      SmallVector<Value> sgOffsets;
-      size_t rank = tdescOffsets.size();
-      for (size_t i = 0; i < rank; i++) {
-        size_t idx = origOffsets.size() - rank + i;
-        Value add = rewriter.createOrFold<index::AddOp>(
-            loc, tdescOffsets[i],
-            getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
-        sgOffsets.push_back(add);
-      }
-
+    SmallVector<Value> newOps;
+    for (auto offsets : offsetsList) {
       auto newOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
+          rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
           op.getMixedSizes(), op.getMixedStrides());
-      newCreateNdOps.push_back(newOp);
+
+      newOps.push_back(newOp);
     }
-    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    rewriter.replaceOpWithMultiple(op, {newOps});
+
     return success();
   }
 };
@@ -223,7 +226,7 @@ struct WgToSgCreateNdOpNoOffset
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
     auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-    if (!layout || !layout.isWgLayout())
+    if (!layout || !layout.isForWorkgroup())
       return failure();

     Type elemTy = tdescTy.getElementType();
@@ -254,12 +257,10 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   LogicalResult
   matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> newLoadOps;
-
-    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
+    if (!op.getMixedOffsets().empty())
       return failure();

+    SmallVector<Value> newLoadOps;
     for (auto src : adaptor.getTensorDesc()) {
       xegpu::TensorDescType tdescTy =
           dyn_cast<xegpu::TensorDescType>(src.getType());
@@ -282,9 +283,7 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   LogicalResult
   matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
+    if (!op.getMixedOffsets().empty())
       return failure();

     for (auto [v, t] :
          llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
@@ -296,100 +295,6 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };

-// Utility function to compute global offsets for subgroup operations.
-// Returns a vector of new offsets for each subgroup, given the original op's
-// offsets and subgroup relative offsets.
-static SmallVector<SmallVector<Value>>
-computeOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
-               ArrayRef<OpFoldResult> origOffsets,
-               ConversionPatternRewriter &rewriter) {
-  SmallVector<SmallVector<Value>> finalOffsets;
-  Location loc = op->getLoc();
-  for (const auto &sgOffsets : sgOffsetsList) {
-    SmallVector<Value> newOffsets;
-    size_t rank = sgOffsets.size();
-    for (size_t i = 0; i < rank; i++) {
-      size_t idx = origOffsets.size() - rank + i;
-      Value add = rewriter.createOrFold<index::AddOp>(
-          loc, sgOffsets[i],
-          getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
-      newOffsets.push_back(add);
-    }
-    finalOffsets.push_back(std::move(newOffsets));
-  }
-  return finalOffsets;
-}
-
-// Utility function to get sgShape, sgOffsetList for a given
-// op.
-template <typename OpTy, typename AdaptorTy>
-LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
-                           ConversionPatternRewriter &rewriter,
-                           SmallVector<int64_t> &sgShape,
-                           SmallVector<SmallVector<Value>> &sgOffsetList) {
-  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-  if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
-    return failure();
-
-  Location loc = op.getLoc();
-  Value tdesc = op.getTensorDesc();
-  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
-  if (!tdescTy)
-    return failure();
-  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-  if (!layout)
-    return failure();
-
-  SmallVector<int64_t> sgLayout;
-  auto sgLayoutAttr = layout.getSgLayout();
-  if (!sgLayoutAttr)
-    return rewriter.notifyMatchFailure(
-        op, "sgLayout attribute is required in layout");
-  sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-
-  ArrayRef<int64_t> wgShape = tdescTy.getShape();
-  int count;
-  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
-
-  // Get the subgroup ID
-  Value linearSgId =
-      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-  int64_t startOfRange = -1, endOfRange = -1;
-  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
-  if (sgIdRangeSpecified) {
-    int64_t sgCount = endOfRange - startOfRange;
-    if (computeProduct(sgLayout) != sgCount)
-      return rewriter.notifyMatchFailure(
-          op, "sg_layout size must match the sg_id_range");
-    Value startOfRangeVal =
-        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
-    linearSgId =
-        rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
-  }
-
-  auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-  if (failed(sgOffsets))
-    return failure();
-
-  sgOffsetList = *sgOffsets;
-  return success();
-}
-
-template <typename OpTy>
-SmallVector<OpFoldResult> getOffsets(OpTy op,
-                                     ConversionPatternRewriter &rewriter) {
-  SmallVector<OpFoldResult> origOffsets;
-  if (auto constOffsets = op.getConstOffsetsAttr()) {
-    for (auto attr : constOffsets.asArrayRef())
-      origOffsets.push_back(rewriter.getIndexAttr(attr));
-  }
-  for (auto v : op.getOffsets())
-    origOffsets.push_back(v);
-  return origOffsets;
-}
-
 // This pattern transforms the LoadNdOp with explicit offsets to load
 // subgroup data.
 struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
@@ -398,33 +303,24 @@
   matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<int64_t> sgShape;
-    SmallVector<SmallVector<Value>> sgOffsetList;
-
-    // Do the distribution from workgroup to subgroup and get subgroup offsets
-    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();

-    // Get the original workgroup offsets
-    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
-
-    // Calculate the final offsets for each subgroup
-    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
-
-    SmallVector<Value> newLoadOps;
-    for (auto [offsets, tdesc] :
-         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
-      VectorType newResTy = VectorType::get(
-          sgShape,
-          dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
-      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
-          op.getLoc(), newResTy, tdesc, offsets,
-          /*packed=*/nullptr,
-          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
-          op.getL3HintAttr());
-      newLoadOps.push_back(newLoadOp);
+    SmallVector<Value> newOps;
+    for (auto [tdesc, offsets] :
+         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
+      auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+      VectorType newResTy =
+          VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
+      auto newOp = xegpu::LoadNdOp::create(
+          rewriter, op.getLoc(), newResTy, tdesc, offsets,
+          /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      newOps.push_back(newOp);
     }
-    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    rewriter.replaceOpWithMultiple(op, {newOps});
+
     return success();
   }
 };
@@ -437,27 +333,18 @@ struct WgToSgStoreNdOpWithOffset
   LogicalResult
   matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    SmallVector<int64_t> sgShape;
-    SmallVector<SmallVector<Value>> sgOffsetList;
-
-    // Do the distribution from workgroup to subgroup and get subgroup offsets
-    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();

-    // Get the original workgroup offsets
-    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
-
-    // Calculate the final offsets for each subgroup
-    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
-
-    for (auto [offsets, tdesc, value] :
-         llvm::zip(finalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
-      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
+    for (auto [v, tdesc, offsets] :
+         llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, tdesc, offsets,
                                         op.getL1HintAttr(), op.getL2HintAttr(),
                                         op.getL3HintAttr());
     }
     rewriter.eraseOp(op);
+
     return success();
   }
 };
@@ -470,27 +357,18 @@ struct WgToSgPrefetchNdOpWithOffset
   LogicalResult
   matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    SmallVector<int64_t> sgShape;
-    SmallVector<SmallVector<Value>> sgOffsetList;
-
-    // Do the distribution from workgroup to subgroup and get subgroup offsets
-    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();

-    // Get the original workgroup offsets
-    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
-
-    // Calculate the final offsets for each subgroup
-    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
-
-    for (auto [offsets, tdesc] :
-         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
+    for (auto [tdesc, offsets] :
+         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
       rewriter.create<xegpu::PrefetchNdOp>(
           op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
           op.getL3HintAttr());
     }
     rewriter.eraseOp(op);
+
     return success();
   }
 };
@@ -736,7 +614,8 @@ struct WgToSgConvertLayoutOp
     xegpu::LayoutAttr input = op.getInputLayout();
     xegpu::LayoutAttr target = op.getTargetLayout();

-    if (!input || !target || !input.isWgLayout() || !target.isWgLayout())
+    if (!input || !target || !input.isForWorkgroup() ||
+        !target.isForWorkgroup())
      return rewriter.notifyMatchFailure(
           op, "Input and target layouts must have subgroup layout");

@@ -884,6 +763,56 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };

+struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
+  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
+      return failure();
+
+    ArrayRef<int64_t> wgShape = op.getDataShape();
+    VectorType valueTy = op.getRes().getType();
+    Type elemTy = valueTy.getElementType();
+
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResTy = VectorType::get(sgShape, elemTy);
+    SmallVector<Value> newOps;
+    for (auto offsets : offsetsList) {
+      auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
+          op.getLoc(), newResTy, op.getMemDesc(), offsets,
+          layout.dropSgLayoutAndData());
+      newOps.push_back(newOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newOps});
+
+    return success();
+  }
+};
+
+struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
+  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
+      return failure();
+
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
+      rewriter.create<xegpu::StoreMatrixOp>(op.getLoc(), v, op.getMemDesc(),
+                                            offsets,
+                                            layout.dropSgLayoutAndData());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace

 namespace mlir {
@@ -895,7 +824,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
       WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
       WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-      WgToSgArithConstantOp>(patterns.getContext());
+      WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -985,8 +915,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return xegpu::TensorDescType();
   };

-  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
-    return !layout || !layout.isWgLayout();
+  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
+    return !layout || !layout.isForWorkgroup();
   };

-  target.addDynamicallyLegalOp<vector::BroadcastOp>(
-      [=](vector::BroadcastOp op) -> bool {
-        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
+      [=](xegpu::LoadMatrixOp op) -> bool {
+        return isLegal(op.getLayoutAttr());
+      });
+
+  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
+      [=](xegpu::StoreMatrixOp op) -> bool {
+        return isLegal(op.getLayoutAttr());
       });

   target.addDynamicallyLegalOp(
@@ -1015,6 +950,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });

+  target.addDynamicallyLegalOp<vector::BroadcastOp>(
+      [=](vector::BroadcastOp op) -> bool {
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
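For orientation, this is how a downstream pass could drive the new patterns together with the matrix-op legality rules; the surrounding pass harness (`ctx`, `op`, `signalPassFailure`) is hypothetical, while the populate function and legality logic are the ones from this patch:

```cpp
RewritePatternSet patterns(ctx);
xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
ConversionTarget target(*ctx);
// A load_matrix/store_matrix is legal once its layout no longer carries
// workgroup-level sg_layout/sg_data.
target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
    [](xegpu::LoadMatrixOp op) {
      xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
      return !layout || !layout.isForWorkgroup();
    });
if (failed(applyPartialConversion(op, target, std::move(patterns))))
  signalPassFailure();
```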
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 19eedbac0f76..6835f64ad8ef 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -12,6 +12,7 @@

 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -40,7 +41,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
   auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
   // It only works for subgroup level layout, which only has lane_layout
   // and lane_data, and is to distribute a SIMD code into SIMT code.
-  if (!layout || !layout.isSgLayout())
+  if (!layout || !layout.isForSubgroup())
     return failure();

   SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
@@ -424,3 +425,31 @@ std::optional<std::string> xegpu::getChipStr(Operation *op) {

   return std::nullopt;
 }
+
+/// Generates element-wise addition ops of two arrays with automatic
+/// alignment. When the input arrays have different sizes, the shorter array
+/// is right-aligned with the longer array, and the unmatched leading
+/// elements from the longer array are preserved unchanged. This is commonly
+/// used for offset computation where higher-dimensional offsets need to be
+/// added to lower-dimensional adjustments.
+///
+/// Example:
+///   lhs = [l1, l2, l3], rhs = [r1, r2]
+///   Result: [l1, l2+r1, l3+r2]
+SmallVector<OpFoldResult>
+xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
+                           ArrayRef<OpFoldResult> lhs,
+                           ArrayRef<OpFoldResult> rhs) {
+  // ensure a is longer than b
+  ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
+  ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
+  SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
+  a = a.slice(a.size() - b.size());
+  for (auto [l, r] : llvm::zip(a, b)) {
+    auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
+    auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
+    results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+  }
+  return results;
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 07a0b86223c3..32157a7911f6 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -263,4 +263,62 @@ gpu.module @test_distribution {
   } {sg_id_range = #xegpu.range<[3, 19]>}
     gpu.return
   }
+
+  // CHECK-LABEL: distribute_load_matrix
+  // CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
+  gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
+    //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[c2:%.+]] = arith.constant 2 : index
+    //CHECK: [[c4:%.+]] = arith.constant 4 : index
+    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
+    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
+    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
+    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[l_off_y_0:%.+]] = arith.addi [[l_off_y]], [[c0]] : index
+    //CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index
+    //CHECK: [[c64:%.+]] = arith.constant 64 : index
+    //CHECK: [[off_y:%.+]] = index.remu [[l_off_y_0]], [[c64]]
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[off_x:%.+]] = index.remu [[l_off_x_0]], [[c128]]
+    //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
+    %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], inst_data = [16, 16]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
+    gpu.return
+  }
+
+  //CHECK-LABEL: distribute_store_matrix
+  //CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
+  gpu.func @distribute_store_matrix(%arg0 : memref<32768xi8, 3>) {
+    //CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32>
+    //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[c2:%.+]] = arith.constant 2 : index
+    //CHECK: [[c4:%.+]] = arith.constant 4 : index
+    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
+    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
+    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]]
+    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
+    //CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index
+    //CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index
+    //CHECK: [[c64:%.+]] = arith.constant 64 : index
+    //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]]
+    //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} dense<1.0> : vector<64x128xf32>
+    %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
+    gpu.return
+  }
 }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 58962714b786..200323c7a4e5 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -82,7 +82,7 @@ struct TestXeGPUUnrollingPatterns

           if (auto layout = tdescTy.getLayoutAttr()) {
             auto inst_data = layout.getInstData();
-            if (inst_data && layout.isSgLayout())
+            if (inst_data && layout.isForSubgroup())
               return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
                                           inst_data.asArrayRef().end());
           }
@@ -156,8 +156,8 @@ struct TestXeGPUUnrollingPatterns
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

 // Test pattern for distributing vector::StepOp from workgroup to subgroup.
-// Validates LayoutTrait interfaces for offset computation abstraction between
-// LayoutAttr and SliceAttr.
+// Validates DistributeLayoutAttr interfaces for offset computation
+// abstraction between LayoutAttr and SliceAttr.
 class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
   using OpConversionPattern<vector::StepOp>::OpConversionPattern;

@@ -239,7 +239,7 @@ struct TestXeGPULayoutInterface
     ConversionTarget target(*ctx);

     auto isLegal = [&](xegpu::SliceAttr layout) -> bool {
-      return !layout || !layout.isWgLayout();
+      return !layout || !layout.isForWorkgroup();
     };

     target.addDynamicallyLegalOp<vector::StepOp>(