//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include <algorithm>
#include <optional>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;

namespace {

// Retrieve the RangeAttr if it is specified.
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) { Operation *parent = op->getParentOfType(); while (parent) { if (auto attr = llvm::dyn_cast_if_present( parent->getAttr("sg_id_range"))) return attr; parent = parent->getParentOfType(); } return {}; } static std::pair, int> getSgShapeAndCount(ArrayRef shape, xegpu::DistributeLayoutAttr layout) { int count = 1; SmallVector sgShape(shape); if (layout && layout.isForWorkgroup()) { SmallVector sgLayout = layout.getEffectiveSgLayoutAsInt(); if (!layout.getEffectiveSgDataAsInt().empty()) sgShape = layout.getEffectiveSgDataAsInt(); else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout)) sgShape = *maybeDerivedSgData; SmallVector distUnit = computeElementwiseMul(sgLayout, sgShape); // Clamp distUnit to the original shape to handle cases where data is // shared among subgroups, which may cause distUnit to exceed the original // shape. for (size_t i = 0; i < distUnit.size(); ++i) distUnit[i] = std::min(shape[i], distUnit[i]); count = computeProduct(shape) / computeProduct(distUnit); } return std::make_pair(sgShape, count); } /// Utility helper for deriving a list of offsets for each sub-TensorDescs /// or sub-MemDescs to be accessed by current subgroup (sgId) based on the /// associated distribute layout attribute, the shape, subgroup id and the /// original offsets of the op template < typename OpType, typename = std::enable_if_t::value>> static LogicalResult genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, SmallVector> &offsetsList) { Location loc = op.getLoc(); SmallVector origOffsets = op.getMixedOffsets(); // not applicable to ops without offsets operands. 
if (origOffsets.empty()) return failure(); // if op is xegpu::CreateNdDescOp, call op.getDescLayoutAttr() xegpu::DistributeLayoutAttr layout; if constexpr (std::is_same_v || std::is_same_v) { layout = op.getLayoutAttr(); } else { layout = op.getDescLayoutAttr(); } // not applicable to ops without workgroup layout attributes if (!layout || !layout.isForWorkgroup()) return failure(); Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); // verify and adjust the sgId if the range specifier is present xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op); if (sgIdRange) { int64_t startOfRange = sgIdRange.getStart().getInt(); int64_t endOfRange = sgIdRange.getEnd().getInt(); // verify the RangeAttr against the layout attribute if (layout.getNumSubgroups() != endOfRange - startOfRange) return rewriter.notifyMatchFailure( op, "sg_layout size must match the sg_id_range"); // adjust the sgId if necessary if (startOfRange > 0) { Value startOfRangeVal = arith::ConstantIndexOp::create(rewriter, loc, startOfRange); sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal); } } // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory // descriptors to be accessed, based on the layout information. ArrayRef wgShape = op.getDataShape(); auto maybeDescOffsets = layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(maybeDescOffsets)) return failure(); // Compute the final global offsets for each accessed sub-tensor // or sub-memory descriptor. for (const auto &sgOffsets : *maybeDescOffsets) { SmallVector newOffsets = xegpu::addWithRightAligned( rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets); offsetsList.push_back(std::move(newOffsets)); } // callback(offsetsList); return success(); } /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor /// from a workgroup descriptor. It replaces the offsets and sizes with /// appropriate values for the subgroup. 
/// It uses round-robin assignment to distribute the work to the subgroups.
/// The following create_nd_tdesc operation:
///    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
///       -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
///           sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
/// is converted to 9 subgroup level operations based on the sg_layout &
/// sg_data:
///    %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
///           lane_data = [1, 1]>>
///
/// The sg_layout and sg_data attributes are dropped after the pass as they are
/// no longer needed.
///
/// 24x24 matrix distribution example:
/// sg_layout = [4, 4], sg_data = [2, 2]
/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
///
/// +-----+-----+-----+
/// | 8x8 | 8x8 | 8x8 |      <- 3 tiles across
/// |-----+-----+-----|
/// | 8x8 | 8x8 | 8x8 |      <- 3 tiles down
/// |-----+-----+-----|
/// | 8x8 | 8x8 | 8x8 |
/// +-----+-----+-----+
///
/// Each 8x8 tile is further subdivided among subgroups:
/// +-----------------+
/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups across (each handles 2 columns)
/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups down (each handles 2 rows)
/// | 2x2 2x2 2x2 2x2 |
/// | 2x2 2x2 2x2 2x2 |
/// +-----------------+
///
/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
///
/// The per-subgroup offset computation used here (genOffsetsList) is shared
/// with the other offset-carrying patterns. Ops that consume a distributed
/// descriptor without offsets just follow the descriptors produced here.
/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
/// ops in the pass.
struct WgToSgCreateNdOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector> offsetsList; if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); MLIRContext *ctx = op.getContext(); xegpu::TensorDescType tdescTy = op.getType(); ArrayRef wgShape = tdescTy.getShape(); Type elemTy = tdescTy.getElementType(); xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr(); SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; auto newTdescTy = xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), layout.dropSgLayoutAndData()); SmallVector newOps; for (auto offsets : offsetsList) { auto newOp = xegpu::CreateNdDescOp::create( rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets, op.getMixedSizes(), op.getMixedStrides()); newOps.push_back(newOp); } rewriter.replaceOpWithMultiple(op, {newOps}); return success(); } }; // This pattern transforms the CreateNdDescOp without offsets to create a // subgroup descriptor from a workgroup descriptor struct WgToSgCreateNdOpNoOffset : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check no offsets are specified. 
if (!op.getMixedOffsets().empty()) return failure(); Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); xegpu::TensorDescType tdescTy = op.getType(); auto layout = dyn_cast(tdescTy.getLayout()); if (!layout || !layout.isForWorkgroup()) return failure(); Type elemTy = tdescTy.getElementType(); ArrayRef wgShape = tdescTy.getShape(); SmallVector sgShape; int count; std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); xegpu::TensorDescType newTdescTy = xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), layout.dropSgLayoutAndData()); SmallVector newCreateNdOps(count); std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() { return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(), op.getMixedStrides()); }); rewriter.replaceOpWithMultiple(op, {newCreateNdOps}); return success(); } }; /// This pattern transforms the LoadNdOp to load subgroup data. struct WgToSgLoadNdOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!op.getMixedOffsets().empty()) return failure(); SmallVector newLoadOps; for (auto src : adaptor.getTensorDesc()) { xegpu::TensorDescType tdescTy = dyn_cast(src.getType()); ArrayRef srcShape = tdescTy.getShape(); VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType()); auto newLoadOp = xegpu::LoadNdOp::create( rewriter, op.getLoc(), newResTy, src, xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs())); newLoadOps.push_back(newLoadOp); } rewriter.replaceOpWithMultiple(op, {newLoadOps}); return mlir::success(); } }; /// This pattern transforms the StoreNdOp to store to a subgroup descriptor /// It creates a StoreNdOp op to store the updated values to the new subgroup /// src tensor descriptors. 
struct WgToSgStoreNdOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!op.getMixedOffsets().empty()) return failure(); for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc())) xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); rewriter.eraseOp(op); return success(); } }; // This pattern transforms the LoadNdOp with explicit offsets to load // subgroup data. struct WgToSgLoadNdOpWithOffset : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector> offsetsList; if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); if (layout) layout = layout.dropSgLayoutAndData(); SmallVector newOps; for (auto [tdesc, offsets] : llvm::zip(adaptor.getTensorDesc(), offsetsList)) { auto tdescTy = dyn_cast(tdesc.getType()); VectorType newResTy = VectorType::get(tdescTy.getShape(), tdescTy.getElementType()); auto newOp = xegpu::LoadNdOp::create( rewriter, op.getLoc(), newResTy, tdesc, offsets, /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), layout); newOps.push_back(newOp); } rewriter.replaceOpWithMultiple(op, {newOps}); return success(); } }; // This pattern transforms the StoreNdOp with explicit offsets to store // subgroup data. 
struct WgToSgStoreNdOpWithOffset : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector> offsetsList; if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); if (layout) layout = layout.dropSgLayoutAndData(); for (auto [v, tdesc, offsets] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) { xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), layout); } rewriter.eraseOp(op); return success(); } }; // This pattern transforms the PrefetchNdOp with explicit offsets to prefetch // subgroup data. struct WgToSgPrefetchNdOpWithOffset : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector> offsetsList; if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); if (layout) layout = layout.dropSgLayoutAndData(); for (auto [tdesc, offsets] : llvm::zip(adaptor.getTensorDesc(), offsetsList)) { xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), layout); } rewriter.eraseOp(op); return success(); } }; /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a /// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the /// offsets of the new subgroup src tensor descriptors. 
struct WgToSgUpdateNdOffsetOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { llvm::SmallVector newUpdateTileOffsetOps; for (auto tDesc : adaptor.getTensorDesc()) { auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create( rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(), op.getConstOffsets()); newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp); } rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps}); return success(); } }; /// This pattern transforms the DpasOp to work at subgroup level. struct WgToSgDpasOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); VectorType resultTy = op.getResult().getType(); if (resultTy.getRank() != 2) return failure(); auto layoutCd = op.getLayoutCdAttr(); auto layoutA = op.getLayoutAAttr(); auto layoutB = op.getLayoutBAttr(); if (!layoutCd || !layoutA || !layoutB) return failure(); size_t i = 0; SmallVector newDpasOps; for (auto aVec : adaptor.getLhs()) { for (auto bVec : adaptor.getRhs()) { llvm::SmallVector operands({aVec, bVec}); Value tmpC; if (op.getAcc()) { tmpC = adaptor.getAcc()[i++]; operands.push_back(tmpC); } ArrayRef aVecShape = llvm::cast(aVec.getType()).getShape(); ArrayRef bVecShape = llvm::cast(bVec.getType()).getShape(); VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]}, resultTy.getElementType()); auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands); newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData()); newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData()); newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData()); newDpasOps.push_back(newDpasOp); } } rewriter.replaceOpWithMultiple(op, {newDpasOps}); 
return success(); } }; /// This pattern transforms the PrefetchNdOp to prefetch the subgroup data. struct WgToSgPrefetchNdOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { int64_t offsetSize = static_cast(op.getOffsets().size()); if ((offsetSize != 0) || op.getConstOffsetsAttr()) return failure(); for (auto src : adaptor.getTensorDesc()) xegpu::PrefetchNdOp::create( rewriter, op.getLoc(), TypeRange(), src, xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs())); rewriter.eraseOp(op); return success(); } }; /// This pattern transforms vector.broadcast ops to work at subgroup level. struct WgToSgVectorBroadcastOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType resultType = op.getResult().getType(); ArrayRef wgShape = resultType.getShape(); xegpu::DistributeLayoutAttr layout = xegpu::getTemporaryLayout(llvm::cast(op.getResult())); if (!layout || !layout.isForWorkgroup()) return failure(); SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; VectorType newResultType = VectorType::get(sgShape, resultType.getElementType()); if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout)) return failure(); SmallVector newBroadcastOps; for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), newResultType, operand); newBroadcastOps.push_back(newBroadcast.getResult()); } rewriter.replaceOpWithMultiple(op, {newBroadcastOps}); return success(); } }; // This pattern transforms elementwise ops to work at subgroup level. 
struct WgToSgElementwiseOp : public ConversionPattern { WgToSgElementwiseOp(MLIRContext *ctx) : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Only match ops with elementwise trait and single result. if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) return failure(); auto resultType = dyn_cast(op->getResult(0).getType()); assert(resultType && "Expected result to be a VectorType"); ArrayRef wgShape = resultType.getShape(); xegpu::DistributeLayoutAttr layout = xegpu::getTemporaryLayout(llvm::cast(op->getResult(0))); if (!layout || !layout.isForWorkgroup()) return failure(); SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; size_t numVariants = operands.empty() ? 0 : operands.front().size(); if (llvm::any_of(operands, [&](const ValueRange &operandVec) { return operandVec.size() != numVariants; })) return failure(); SmallVector newResults; VectorType newResultType = VectorType::get(sgShape, resultType.getElementType()); for (size_t i = 0; i < numVariants; ++i) { SmallVector opOperands; for (auto &operandVec : operands) opOperands.push_back(operandVec[i]); OperationState state(op->getLoc(), op->getName()); state.addOperands(opOperands); state.addTypes(newResultType); state.addAttributes(op->getAttrs()); Operation *newOp = rewriter.create(state); xegpu::removeLayoutAttrs(newOp); newResults.push_back(newOp->getResult(0)); } rewriter.replaceOpWithMultiple(op, {newResults}); return success(); } }; // clang-format off // Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data. // If input_layout and target_layout have identical sg_layout and sg_data, // the op is rewritten to a subgroup-level ConvertLayoutOp with these fields // dropped. 
For example: // #a = #xegpu.layout // #b = #xegpu.layout // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32> // becomes: // #a = #xegpu.layout // #b = #xegpu.layout // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32> // (vector<16x16xf32> is determined by sg_data = [16, 16]) // // If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups. // For example: // #a = #xegpu.layout // #b = #xegpu.layout // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32> // is lowered to: // #a = #xegpu.layout // #b = #xegpu.layout // store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32> // %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32> // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32> // clang-format on struct WgToSgConvertLayoutOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto input = op.getInputLayout(); auto target = op.getTargetLayout(); if (!input || !target || !input.isForWorkgroup() || !target.isForWorkgroup()) return rewriter.notifyMatchFailure( op, "Input and target layouts must have subgroup layout"); SmallVector inputSgLayout = input.getEffectiveSgLayoutAsInt(); SmallVector inputSgData = input.getEffectiveSgDataAsInt(); DenseI32ArrayAttr inputOrder = input.getOrder(); SmallVector targetSgLayout = target.getEffectiveSgLayoutAsInt(); SmallVector targetSgData = target.getEffectiveSgDataAsInt(); DenseI32ArrayAttr targetOrder = target.getOrder(); // TODO: currently we only support for optimal case, where input and // output has the same sg_layout and sg_data, so SLM is not involved. 
if (inputSgLayout != targetSgLayout || inputSgData != targetSgData || inputOrder != targetOrder) return failure(); input = input.dropSgLayoutAndData(); target = target.dropSgLayoutAndData(); SmallVector newOps(adaptor.getSource()); if (input && target) { // keep the ConvertLayoutOp for rest fields, e.g., inst_data. for (auto [i, src] : llvm::enumerate(adaptor.getSource())) { auto newOp = xegpu::ConvertLayoutOp::create( rewriter, op.getLoc(), src.getType(), src, input, target); newOps[i] = newOp; } } rewriter.replaceOpWithMultiple(op, {newOps}); return success(); } }; // Handles UnrealizedConversionCastOp generated during // SCFStructuralTypeConversions (step 1). This op may appear as either a // target or source materialization for Vector values, e.g.: // 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ... // 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32> // it could be either 1:N or N:1 cast. In both cases, the pattern // simply forwards the inputs to the outputs using 1:1 or 1:N interface. // for example, the following scf::forOp // ``` // %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) { // %n = use(%arg1): vector<128x128xf16> // scf.yield %n : vector<128x128xf16> // } // ``` // Could be converted to: // ``` // %1 = unrealized_conversion_cast %0 // : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16> // %for:2 = scf.for ... iter_args(%arg1 = %1#1, %arg2 = %1#2) // -> (vector<16x16xf16>, vector<16x16xf16) { // %m = unrealized_conversion_cast %arg1, %arg2 // : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16> // %n = use(%m): vector<128x128xf16> // %b = unrealized_conversion_cast %n // : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16> // scf.yield %b#1, %b#2 : vector<16x16xf16>, vector<16x16xf16> // } // %cast = unrealized_conversion_cast %for:2 // : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16> // ``` // TODO: remove it when context-aware type converter is ready. 
struct UnrealizedConversionCastOpPattern : public OpConversionPattern { using OpConversionPattern< mlir::UnrealizedConversionCastOp>::OpConversionPattern; mlir::LogicalResult matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector inputs = xegpu::flattenValues(adaptor.getInputs()); auto inputTy = dyn_cast(inputs[0].getType()); auto outputTy = dyn_cast(op->getOpResult(0).getType()); if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) || !llvm::all_equal(ValueRange(inputs).getTypes())) return failure(); // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...". // It is generated by source materialization (e.g., inits to scf forOp). // The input values provided by the adaptor should already be distributed, // and their types should correspond exactly to the result types of the // operation. if (op.getNumOperands() == 1 && llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) { rewriter.replaceOp(op, inputs); return success(); } // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>". // It is generated by target materialization (e.g., arguments/results // of scf forOp). All input values must have the same vector type, and // their shape must be evenly divisible by the output vector's shape // (determined by the nature of the workgroup to subgroup distribution). // TODO: it is not safe to do such forward, since such N:1 cast could be // from others. 
if (op.getNumResults() == 1 && computeShapeRatio(outputTy.getShape(), inputTy.getShape())) { rewriter.replaceOpWithMultiple(op, {inputs}); return success(); } return mlir::failure(); } }; // This pattern distributes arith.constant op into subgroup-level constants struct WgToSgArithConstantOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto vecAttr = dyn_cast(op.getValue()); auto vecType = dyn_cast(op.getType()); if (!vecAttr || !vecType) return failure(); xegpu::DistributeLayoutAttr layout = xegpu::getTemporaryLayout(dyn_cast(op.getResult())); if (!layout || !layout.isForWorkgroup()) return failure(); ArrayRef wgShape = vecType.getShape(); SmallVector sgShape; int count; std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); auto newType = VectorType::get(sgShape, vecType.getElementType()); Location loc = op.getLoc(); auto eltType = vecType.getElementType(); if (vecAttr.isSplat()) { // Splat: single value for all subgroups Attribute singleVal = vecAttr.getSplatValue(); auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); rewriter.replaceOp(op, cstOp); return success(); } else if (sgShape == wgShape) { // if the entire vector is shared by all // subgroups, don't distribute auto newConstOp = arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); rewriter.replaceOp(op, newConstOp); return success(); } else { // Non-splat constant // Only supports 1D & 2D // TODO: support other cases that require SLM access if (!eltType.isIndex()) return rewriter.notifyMatchFailure( op, "Unsupported element type for non-splat constant op."); if (wgShape.size() > 2) return rewriter.notifyMatchFailure( op, "Only 1D & 2D vector constant supported"); SmallVector values(vecAttr.getValues()); int64_t rowStride = 0, colStride = 0; 
int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0]; int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1]; // Compute colStride and rowStride, and check for constant strides. if (cols > 1) { colStride = cast(values[1]).getInt() - cast(values[0]).getInt(); } if (rows > 1) { rowStride = cast(values[cols]).getInt() - cast(values[0]).getInt(); } for (int64_t r = 0; r < rows; ++r) { for (int64_t c = 0; c < cols; ++c) { int64_t idx = r * cols + c; // Check column stride if (c > 0 && cols > 1) { int64_t prevIdx = r * cols + (c - 1); int64_t diff = cast(values[idx]).getInt() - cast(values[prevIdx]).getInt(); if (diff != colStride) return rewriter.notifyMatchFailure( op, "Non-constant column stride in constant op."); } // Check row stride if (r > 0 && rows > 1) { int64_t prevIdx = (r - 1) * cols + c; int64_t diff = cast(values[idx]).getInt() - cast(values[prevIdx]).getInt(); if (diff != rowStride) return rewriter.notifyMatchFailure( op, "Non-constant row stride in constant op."); } } } // Create a constant for the base tile. // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix. // For 1D case, extract the first sgShape[0] elements. SmallVector baseTileValues; int baseTileCols = sgShape[sgShape.size() - 1]; int64_t baseTileRows = sgShape.size() == 1 ? 
1 : sgShape[0]; for (int64_t r = 0; r < baseTileRows; ++r) { for (int64_t c = 0; c < baseTileCols; ++c) { baseTileValues.push_back(values[r * cols + c]); } } auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType), baseTileValues); auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr); // Get subgroup id Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); auto sgOffsets = layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(sgOffsets)) return failure(); SmallVector strideConsts; strideConsts.push_back( arith::ConstantIndexOp::create(rewriter, loc, colStride)); if (rows > 1) strideConsts.insert( strideConsts.begin(), arith::ConstantIndexOp::create(rewriter, loc, rowStride)); SmallVector newConstOps; for (auto offsets : *sgOffsets) { // Multiply offset with stride, broadcast it and add to baseConstVec Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); for (size_t i = 0; i < strideConsts.size(); ++i) { Value mul = arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(), offsets[i], strideConsts[i]); mulOffset = arith::AddIOp::create( rewriter, loc, rewriter.getIndexType(), mulOffset, mul); } // Broadcast to baseConstVec size auto bcastOffset = vector::BroadcastOp::create( rewriter, loc, baseConstVec.getType(), mulOffset); auto finalConst = arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); newConstOps.push_back(finalConst); } rewriter.replaceOpWithMultiple(op, {newConstOps}); return success(); } } }; // This pattern transforms the LoadGatherOp with explicit offsets to load // subgroup data struct WgToSgLoadGatherOpWithOffset : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!op.getOffsets()) return failure(); Location loc = op.getLoc(); VectorType resultType = 
dyn_cast(op.getResult().getType()); if (!resultType) return failure(); ArrayRef wgShape = resultType.getShape(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); if (!layout || !layout.isForWorkgroup()) return failure(); SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; // The offsets need to be distributed auto offsetsVecType = dyn_cast(adaptor.getOffsets().front().getType()); auto maskVecType = dyn_cast(adaptor.getMask().front().getType()); if (!offsetsVecType || !maskVecType || offsetsVecType.getShape() != maskVecType.getShape()) { return rewriter.notifyMatchFailure(op, "offsets have not been distributed"); } SmallVector newLoadOps; auto chunkSizeAttr = rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1)); VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); for (auto [offsets, mask] : llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { auto newLayout = layout.dropSgLayoutAndData(); auto newLoadOp = xegpu::LoadGatherOp::create( rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), newLayout); newLoadOps.push_back(newLoadOp); } rewriter.replaceOpWithMultiple(op, {newLoadOps}); return success(); } }; // This pattern transforms the StoreScatterOp with explicit offsets to store // subgroup data struct WgToSgStoreScatterOpWithOffset : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!op.getOffsets()) return failure(); Location loc = op.getLoc(); VectorType valueType = dyn_cast(op.getValue().getType()); if (!valueType) return failure(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); if (!layout || !layout.isForWorkgroup()) return failure(); // The offsets need to be distributed auto offsetsVecType = dyn_cast(adaptor.getOffsets().front().getType()); auto maskVecType = 
dyn_cast(adaptor.getMask().front().getType()); if (!offsetsVecType || !maskVecType || offsetsVecType.getShape() != maskVecType.getShape()) { return rewriter.notifyMatchFailure(op, "offsets have not been distributed"); } auto chunkSizeOpt = op.getChunkSize(); int64_t chunkSize = chunkSizeOpt ? static_cast(*chunkSizeOpt) : 1; auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); for (auto [val, offs, mask] : llvm::zip( adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), layout.dropSgLayoutAndData()); } rewriter.eraseOp(op); return success(); } }; struct WgToSgLoadMatrixOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector> offsetsList; if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); ArrayRef wgShape = op.getDataShape(); VectorType valueTy = llvm::dyn_cast(op.getRes().getType()); assert(valueTy && "the value type must be vector type!"); Type elemTy = valueTy.getElementType(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; VectorType newResTy = VectorType::get(sgShape, elemTy); SmallVector newOps; for (auto offsets : offsetsList) { auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy, op.getMemDesc(), offsets, layout.dropSgLayoutAndData()); newOps.push_back(newOp); } rewriter.replaceOpWithMultiple(op, {newOps}); return success(); } }; struct WgToSgStoreMatrixOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector> offsetsList; if 
(failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList)) xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(), offsets, layout.dropSgLayoutAndData()); rewriter.eraseOp(op); return success(); } }; // This pattern distributes the vector.step ops to work at subgroup level struct WgToSgVectorStepOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { xegpu::DistributeLayoutAttr layout = xegpu::getTemporaryLayout(dyn_cast(op.getResult())); if (!layout || !layout.isForWorkgroup()) return failure(); Location loc = op.getLoc(); VectorType type = op.getResult().getType(); auto wgShape = type.getShape(); std::optional> sgShape = getSgShapeAndCount(wgShape, layout).first; if (!sgShape) return failure(); Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); auto sgOffsets = layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(sgOffsets)) return failure(); VectorType newTy = type.cloneWith(*sgShape, type.getElementType()); auto steps = vector::StepOp::create(rewriter, loc, newTy); SmallVector newOps; for (auto offsets : *sgOffsets) { // Broadcast the offset scalar to a vector & add to the base steps auto bcastOffset = vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]); auto finalSteps = arith::AddIOp::create(rewriter, loc, steps, bcastOffset); newOps.push_back(finalSteps); } rewriter.replaceOpWithMultiple(op, {newOps}); return success(); } }; // This pattern transforms vector.shape_cast ops to work at subgroup level. 
struct WgToSgVectorShapeCastOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType resultType = dyn_cast(op.getResult().getType()); if (!resultType) return failure(); ArrayRef wgShape = resultType.getShape(); xegpu::DistributeLayoutAttr layout = xegpu::getTemporaryLayout(dyn_cast(op.getResult())); if (!layout || !layout.isForWorkgroup()) return failure(); // Check that srcShape and destShape, if they differ, only differ by // expand of unit dimensions. auto srcType = dyn_cast(op.getSource().getType()); if (!srcType) return failure(); ArrayRef srcShape = srcType.getShape(); xegpu::DistributeLayoutAttr layoutToDistribute = layout; SmallVector expandedUnitDims; if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) { xegpu::DistributeLayoutAttr sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0)); auto usedByBroadcastOp = [](vector::ShapeCastOp op) { return llvm::all_of(op.getResult().getUsers(), [](Operation *user) { return isa(user); }); }; if (!usedByBroadcastOp(op)) return rewriter.notifyMatchFailure( op, "ShapeCast ops that expand unit dimensions and are used by " "non-broadcast operations are not supported."); if (!sourceLayout.isSliceOf(layout)) return rewriter.notifyMatchFailure( op, "The ShapeCast op only expands dimensions, the result layout " "must be a slice of the input layout, or vice versa."); layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims); layoutToDistribute = layoutToDistribute.setUnitDimLayout(expandedUnitDims); } SmallVector sgShape = getSgShapeAndCount(wgShape, layoutToDistribute).first; VectorType newResultType = VectorType::get(sgShape, resultType.getElementType()); SmallVector newShapeCastOps; for (auto src : adaptor.getSource()) { auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), newResultType, src); 
newShapeCastOps.push_back(newShapeCast.getResult()); } rewriter.replaceOpWithMultiple(op, {newShapeCastOps}); return success(); } }; static Value createAccumulator(ConversionPatternRewriter &rewriter, Location loc, VectorType type, vector::CombiningKind kind) { Type elemTy = type.getElementType(); switch (kind) { case vector::CombiningKind::ADD: case vector::CombiningKind::XOR: case vector::CombiningKind::OR: return arith::ConstantOp::create( rewriter, loc, type, DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy))); case vector::CombiningKind::MUL: case vector::CombiningKind::AND: return arith::ConstantOp::create( rewriter, loc, type, DenseElementsAttr::get(type, rewriter.getOneAttr(elemTy))); case vector::CombiningKind::MINSI: // Use max signed int value for signed integer min if (auto intTy = dyn_cast(elemTy)) { auto maxVal = APInt::getSignedMaxValue(intTy.getWidth()); return arith::ConstantOp::create( rewriter, loc, type, DenseElementsAttr::get(type, rewriter.getIntegerAttr(elemTy, maxVal))); } return nullptr; case vector::CombiningKind::MINUI: if (auto intTy = dyn_cast(elemTy)) { auto maxVal = APInt::getMaxValue(intTy.getWidth()); return arith::ConstantOp::create( rewriter, loc, type, DenseElementsAttr::get(type, rewriter.getIntegerAttr(elemTy, maxVal))); } return nullptr; case vector::CombiningKind::MAXSI: if (auto intTy = dyn_cast(elemTy)) { auto minVal = APInt::getSignedMinValue(intTy.getWidth()); return arith::ConstantOp::create( rewriter, loc, type, DenseElementsAttr::get(type, rewriter.getIntegerAttr(elemTy, minVal))); } return nullptr; case vector::CombiningKind::MAXUI: return arith::ConstantOp::create( rewriter, loc, type, DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy))); case vector::CombiningKind::MINNUMF: case vector::CombiningKind::MINIMUMF: // Use +infinity for float min operations if (auto floatTy = dyn_cast(elemTy)) { auto posInf = APFloat::getInf(floatTy.getFloatSemantics()); return arith::ConstantOp::create( rewriter, loc, 
type, DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, posInf))); } return nullptr; case vector::CombiningKind::MAXNUMF: case vector::CombiningKind::MAXIMUMF: // Use -infinity for float max operations if (auto floatTy = dyn_cast(elemTy)) { auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true); return arith::ConstantOp::create( rewriter, loc, type, DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, negInf))); } return nullptr; } return nullptr; } /// This function converts multi-dimensional subgroup indices into a single /// linear offset. It's used to calculate memory offsets in SLM for /// cross-subgroup reduction coordination. /// /// Parameters: /// - sgIds: Multi-dimensional subgroup indices (e.g., [sgId_x, sgId_y, sgId_z]) /// - dims: Which dimensions to include in linearization (e.g., [0, 2] for x and /// z dims) /// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means /// 4x8x2 subgroups) /// /// It uses row-major linearization formula: /// offset = sum(sgIds[dim] * stride[dim]) /// where stride[dim] = product of all sgLayout sizes in dimensions after /// 'dim' /// /// Example: /// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions) /// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1) /// - Calculation: /// * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1 /// * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4 /// * linearizedOffset = 1 + 4 = 5 /// /// This gives us a unique linear index for each combination of subgroup /// positions in the specified dimensions, which is used for SLM row/column /// addressing. 
static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter, Location loc, ArrayRef sgIds, ArrayRef dims, ArrayRef sgLayout) { Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); int64_t stride = 1; for (int64_t dim : dims) { Value dimVal = sgIds[dim]; Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride); Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal); linearizedOffset = arith::AddIOp::create(rewriter, loc, linearizedOffset, term); stride *= sgLayout[dim]; } return linearizedOffset; } /// This pattern transforms vector.multi_dim_reduction operations from /// workgroup-level to subgroup-level execution with support for multiple /// reduction dimensions. /// /// Steps include: /// 1. LOCAL REDUCTION : /// - Each subgroup performs local reduction on its data slice /// - Uses ZERO accumulator to avoid double-counting during cross-subgroup /// phase /// /// 2. CROSS-SUBGROUP : /// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in /// reduction dims & sgData[reduction dims] < wgData[reduction dims]) /// - If not needed, adds original accumulator and returns local results /// /// 3. 
SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed): /// a) SLM Layout Design: /// - Rows: subgroups participating in reduction (product of sg_layout in /// reduction dims) /// - Cols: total result elements across non-reduction dimensions /// /// b) Store Phase: /// - Each subgroup stores its local reduction result to SLM /// - Row offset: linearized index of subgroup in reduction dimensions /// - Col offset: linearized index of subgroup in non-reduction dimensions /// /// c) Load and Final Reduction Phase: /// - Each subgroup loads a column of data (all reduction participants for /// its position) /// - Performs final reduction along the loaded dimension /// - Adds original accumulator to get final result /// struct WgToSgMultiDimReductionOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); VectorType srcType = op.getSourceVectorType(); VectorType dstType = dyn_cast(op.getResult().getType()); if (!dstType) return failure(); auto originalSrcShape = srcType.getShape(); xegpu::DistributeLayoutAttr layout = xegpu::getTemporaryLayout(dyn_cast(op.getResult())); if (!layout || !layout.isForWorkgroup()) return failure(); auto reductionDims = llvm::to_vector(op.getReductionDims()); // Get sg_layout and sg_data from the parent layout SmallVector sgLayout; SmallVector sgData; if (auto sliceAttr = dyn_cast(layout)) { sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt(); sgData = sliceAttr.getParent().getEffectiveSgDataAsInt(); } else return rewriter.notifyMatchFailure( op, "Reduction should have SliceAttr layout"); Type elemTy = dstType.getElementType(); // Step 1: perform local subgroup reductions with ZERO accumulator SmallVector localReductions; SmallVector sgShape = getSgShapeAndCount(originalSrcShape, layout).first; VectorType newDstType = 
VectorType::get(sgShape, elemTy); for (auto sgSrc : adaptor.getSource()) { // Create ZERO accumulator for local reduction auto neutralLocalAcc = createAccumulator(rewriter, loc, newDstType, op.getKind()); // Local reduction with ZERO accumulator auto localReduce = vector::MultiDimReductionOp::create( rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc, reductionDims); localReductions.push_back(localReduce.getResult()); } // Check if cross-subgroup reduction is needed for any reduction dimension SmallVector crossSgReductionDims; for (int64_t reductionDim : reductionDims) { bool needsCrossSubgroupReduction = (sgLayout[reductionDim] > 1) && (sgData[reductionDim] < originalSrcShape[reductionDim]); if (needsCrossSubgroupReduction) { crossSgReductionDims.push_back(reductionDim); } } // If no cross-subgroup reduction needed, add accumulator and return if (crossSgReductionDims.empty()) { SmallVector results; for (auto localResult : localReductions) { auto finalResult = vector::makeArithReduction( rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]); results.push_back(finalResult); } rewriter.replaceOpWithMultiple(op, {results}); return success(); } // Step 2: cross-subgroup reduction using SLM // Calculate total elements in local result int64_t localElements = computeProduct(sgShape); // Shape cast for SLM storage - store as [1, localElements] SmallVector storeShape2D = {1, localElements}; VectorType storeType2D = VectorType::get(storeShape2D, elemTy); auto storeShapeCast = vector::ShapeCastOp::create( rewriter, loc, storeType2D, localReductions[0]); Value storeData = storeShapeCast.getResult(); // Calculate SLM shape - rows for sg's in reduction dims, cols for total // result elements across all subgroups in non-reduction dimensions int64_t totalReductionSubgroups = 1; for (int64_t dim : crossSgReductionDims) { totalReductionSubgroups *= sgLayout[dim]; } // Total result elements across all subgroups in non-reduction dimensions int64_t 
totalResultElements = localElements * computeProduct(sgLayout) / totalReductionSubgroups; SmallVector slmShape2D = {totalReductionSubgroups, totalResultElements}; // Allocate SLM auto bitWidth = elemTy.getIntOrFloatBitWidth(); auto bytesPerElement = bitWidth / 8; int64_t slmElements = slmShape2D[0] * slmShape2D[1]; auto slmSize = slmElements * bytesPerElement; auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3); auto slm = memref::AllocaOp::create(rewriter, loc, slmTy); auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape2D, elemTy, nullptr); auto memDesc = xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm); // Step 4: Store local results to SLM auto sgId = gpu::SubgroupIdOp::create(rewriter, loc, rewriter.getIndexType(), nullptr); // Convert sgLayout to Values for delinearizeIndex SmallVector sgLayoutValues; for (int64_t dim : sgLayout) sgLayoutValues.push_back( arith::ConstantIndexOp::create(rewriter, loc, dim)); auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(), sgLayoutValues); if (failed(sgIdsResult)) return failure(); SmallVector sgIds = *sgIdsResult; // Row offset: linearize reduction dimension indices Value rowOffsetStore = linearizeSubgroupIndices( rewriter, loc, sgIds, crossSgReductionDims, sgLayout); // Column offset: linearize non-reduction dimension indices SmallVector nonReductionDims; for (size_t i = 0; i < sgLayout.size(); ++i) { if (!llvm::is_contained(reductionDims, static_cast(i))) { nonReductionDims.push_back(static_cast(i)); } } Value colOffset = linearizeSubgroupIndices(rewriter, loc, sgIds, nonReductionDims, sgLayout); Value localElementsVal = arith::ConstantIndexOp::create(rewriter, loc, localElements); colOffset = arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal); SmallVector storeOffsets2D = {rowOffsetStore, colOffset}; xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(), storeOffsets2D, /*layout=*/nullptr); 
gpu::BarrierOp::create(rewriter, loc); // Step 5: Load from SLM for final reduction SmallVector loadShape2D = {totalReductionSubgroups, localElements}; VectorType loadType2D = VectorType::get(loadShape2D, elemTy); // Load offsets - each subgroup loads its column based on non-reduction // position Value rowOffsetLoad = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector loadOffsets2D = {rowOffsetLoad, colOffset}; auto loadOp = xegpu::LoadMatrixOp::create( rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D, /*layout=*/nullptr); // Step 6: Perform final reduction with ZERO accumulator SmallVector finalReductionDims = {0}; SmallVector finalResultShape = {localElements}; VectorType finalResultType = VectorType::get(finalResultShape, elemTy); auto neutralFinalAcc = createAccumulator(rewriter, loc, finalResultType, op.getKind()); auto finalReduce = vector::MultiDimReductionOp::create( rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(), neutralFinalAcc, finalReductionDims); // Step 7: Add the original accumulator at the end Value originalAcc = adaptor.getAcc()[0]; Value accToAdd = originalAcc; // Handle shape mismatch by shape casting if (originalAcc.getType() != finalReduce.getResult().getType()) { auto originalAccType = cast(originalAcc.getType()); auto finalResultType = cast(finalReduce.getResult().getType()); // If they have the same number of elements, just shape cast if (originalAccType.getNumElements() == finalResultType.getNumElements()) { auto shapeCast = vector::ShapeCastOp::create( rewriter, loc, finalResultType, originalAcc); accToAdd = shapeCast.getResult(); } } auto finalResult = vector::makeArithReduction( rewriter, loc, op.getKind(), finalReduce.getResult(), accToAdd); rewriter.replaceOp(op, finalResult); return success(); } }; // This pattern transforms vector.transpose ops to work at subgroup level. 
struct WgToSgVectorTransposeOp : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  // Each distributed source value is transposed with the original permutation
  // into the per-subgroup result shape computed from the result layout. The
  // rewrite is rejected unless the result layout is a transpose of the source
  // layout under that permutation (checked via isTransposeOf).
  LogicalResult matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();
    ArrayRef wgShape = resultType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(dyn_cast(op.getResult()));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    // TODO-LayoutRefactor: handle the case using getTemporaryLayout
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(op.getVector());
    if (!sourceLayout || !sourceLayout.isForWorkgroup())
      return failure();

    SmallVector sourceSgLayout = sourceLayout.getEffectiveSgLayoutAsInt();
    SmallVector resultSgLayout = layout.getEffectiveSgLayoutAsInt();

    // Both layouts must carry an explicit order attribute.
    DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
    DenseI32ArrayAttr resultOrder = layout.getOrder();

    if (!sourceOrder || !resultOrder) {
      return rewriter.notifyMatchFailure(
          op, "Both source and result must have order attributes");
    }

    ArrayRef permutation = op.getPermutation();
    size_t permutationSize = permutation.size();
    if (sourceSgLayout.size() != permutationSize ||
        resultSgLayout.size() != permutationSize) {
      return rewriter.notifyMatchFailure(
          op, "Layouts and permutation must have the same rank");
    }

    // Check that sgLayout, sgData & order are properly transposed for source
    // and result
    if (!layout.isTransposeOf(sourceLayout, permutation))
      return rewriter.notifyMatchFailure(
          op, "Result layout is not a valid transpose of source layout "
              "according to permutation");

    SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    // One subgroup-level transpose per distributed source value.
    SmallVector newTransposeOps;
    for (auto src : adaptor.getVector()) {
      auto newTranspose = vector::TransposeOp::create(
          rewriter, op.getLoc(), newResultType, src, permutation);
      newTransposeOps.push_back(newTranspose.getResult());
    }
    rewriter.replaceOpWithMultiple(op, {newTransposeOps});
    return success();
  }
};

// Distribute vector mask ops to work at subgroup level.
// One template serves both mask ops (see the two aliases below). The
// constant-mask branch materializes the static mask sizes as index
// constants; the create-mask branch uses the dynamic mask-size operands.
// Each subgroup then clamps the workgroup mask sizes to its own tile.
template struct WgToSgVectorMaskOp : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      MaskOpType op,
      typename OpConversionPattern::OneToNOpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(dyn_cast(op.getResult()));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();

    // Workgroup-level mask size per dimension, as index values.
    SmallVector wgMaskDimSizes;
    if constexpr (std::is_same_v) {
      for (int64_t maskSize : op.getMaskDimSizes()) {
        wgMaskDimSizes.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, maskSize));
      }
    } else if constexpr (std::is_same_v) {
      wgMaskDimSizes = llvm::to_vector(op.getOperands());
    }

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    // Start coordinates of every tile owned by this subgroup.
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType resultType = VectorType::get(sgShape, type.getElementType());

    // In each dimension, each subgroup computes its local mask size as:
    // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
    // NOTE(review): `zero` and `dimSizeVal` are loop-invariant constants that
    // are re-created for every tile and dimension; they could be hoisted.
    SmallVector newCreateMaskOps;
    for (auto offsetSet : *sgOffsets) {
      SmallVector maskOperands;

      for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
        Value dimSizeVal =
            arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
        Value offset = offsetSet[i];
        Value adjustedMaskSize =
            arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
        Value nonNegative =
            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
        Value sgMaskSize =
            arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
        maskOperands.push_back(sgMaskSize);
      }

      // Constant and dynamic masks both lower to vector.create_mask here.
      auto newCreateMaskOp =
          vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
      newCreateMaskOps.push_back(newCreateMaskOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
    return success();
  }
};

using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp;
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp;
} // namespace

namespace mlir {
namespace xegpu {
// Registers the workgroup-to-subgroup distribution patterns named in the
// add<...> template argument list below.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns
      .add(
          patterns.getContext());
}
} // namespace xegpu
} // namespace mlir

namespace {
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase {
  void runOnOperation() override;
};
} // namespace

// Entry point of the pass. It recovers temporary layouts, distributes SCF
// region types, then runs a partial dialect conversion whose legality hooks
// treat an op as illegal only while it still carries a workgroup-level layout.
void XeGPUWgToSgDistributePass::runOnOperation() {

  Operation *op = getOperation();
  // The pass fails if temporary layouts cannot be recovered.
  if (!xegpu::recoverTemporaryLayouts(op)) {
    signalPassFailure();
    return;
  }

  // Track existing UnrealizedConversionCastOps
  // (Only these pre-existing casts stay legal below; casts introduced during
  // the conversion are treated as illegal.)
  SmallVector existingCastOps;
  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
    existingCastOps.push_back(castOp.getOperation());
  });

  {
    // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
    // VectorType operands. This first converts such operands to
    // RankedTensorType, propagates the layout attribute into the encoding
    // attribute, and finally converts the RankedTensorType to VectorType based
    // on the encoding.

    TypeConverter converter;
    converter.addConversion([&](Type type) -> Type { return type; });
    // A ranked tensor becomes `count` vectors of the per-subgroup shape, both
    // computed by getSgShapeAndCount from the layout stored in the encoding.
    converter.addConversion(
        [&](RankedTensorType type,
            SmallVectorImpl &result) -> std::optional {
          Type elemTy = type.getElementType();
          ArrayRef shape = type.getShape();

          int count;
          SmallVector subShape;
          std::tie(subShape, count) = getSgShapeAndCount(
              shape,
              dyn_cast_if_present(type.getEncoding()));

          auto newTy = VectorType::get(subShape, elemTy);
          result.append(count, newTy);
          return success();
        });

    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
                                                       converter);
  }

  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
  // as well as XeGPU, Arith, and Vector operations.
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);
  TypeConverter converter;
  converter.addConversion([&](Type type) -> Type { return type; });
  // A tensor_desc becomes `count` tensor_descs of the per-subgroup shape; the
  // sg_layout/sg_data fields are dropped from the layout of the new type.
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl &result) -> std::optional {
        Type elemTy = type.getElementType();
        ArrayRef shape = type.getShape();

        int count;
        SmallVector subShape;
        xegpu::LayoutAttr layout = type.getLayoutAttr();
        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);

        if (layout)
          layout = layout.dropSgLayoutAndData();

        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });

  // Returns the tensor_desc type an op works on, or a null type for any other
  // op.
  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast(op))
      return updateOp.getType();
    if (auto prefetchOp = dyn_cast(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };

  // An op is legal (left untouched) when it has no layout or its layout is
  // not a workgroup-level layout.
  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
    return !layout || !layout.isForWorkgroup();
  };

  target.addDynamicallyLegalOp([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_if_present(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp([=](xegpu::DpasOp op) -> bool {
    auto layout = op.getLayoutCdAttr();
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp(
      [=](xegpu::LoadMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp(
      [=](xegpu::StoreMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  // Scalar constants are always legal; vector constants depend on the
  // temporary layout of their result.
  target.addDynamicallyLegalOp(
      [=](arith::ConstantOp op) -> bool {
        auto vecType = dyn_cast(op.getType());
        if (!vecType)
          return true;

        auto layout =
            xegpu::getTemporaryLayout(dyn_cast(op.getResult()));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp(
      [=](Operation *op) -> bool {
        // Check for either a SliceAttr or LayoutAttr on the result.
        auto layout =
            xegpu::getTemporaryLayout(dyn_cast(op->getResult(0)));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp(
      [=](xegpu::LoadGatherOp op) -> bool {
        auto layout = op.getLayoutAttr();
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp(
      [=](xegpu::StoreScatterOp op) -> bool {
        auto layout = op.getLayoutAttr();
        return isLegal(layout);
      });

  // convert_layout needs rewriting if either side is still workgroup-level.
  target.addDynamicallyLegalOp(
      [=](xegpu::ConvertLayoutOp op) -> bool {
        return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
      });

  // For the dialect(s) listed here, only elementwise-mappable ops whose
  // operands are all vectors of the result's shape are candidates for
  // distribution; every other op in them stays legal.
  target.addDynamicallyLegalDialect(
      [=](Operation *op) -> std::optional {
        // Only handle elementwise mappable ops
        if (!OpTrait::hasElementwiseMappableTraits(op))
          return true;

        VectorType resultType =
            dyn_cast(op->getResult(0).getType());
        if (!resultType)
          return true;

        // Check if all operands are vectors of the same shape
        // TODO: Support other types.
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape()) {
            return true;
          }
        }

        xegpu::DistributeLayoutAttr layout =
            xegpu::getTemporaryLayout(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp(
      [=](UnrealizedConversionCastOp op) {
        return llvm::is_contained(existingCastOps, op.getOperation());
      });

  // Everything not covered above is left alone.
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                       target);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();
}