This PR calls recoverTemporaryLayout before the XeGPUWgToSgDistribute and XeGPUBlocking passes to recover all temporary operand layouts that the transformation patterns may require for checks and verification.
1851 lines
72 KiB
C++
1851 lines
72 KiB
C++
//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
|
|
|
|
#include "mlir/Dialect/Affine/Utils.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
|
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
|
#include "mlir/Dialect/Index/IR/IndexOps.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
|
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
|
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
|
|
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
|
|
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include <optional>
|
|
|
|
namespace mlir {
|
|
namespace xegpu {
|
|
#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
|
|
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
|
|
} // namespace xegpu
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
// Retrieve the RangeAttr if it is specified.
|
|
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
|
|
Operation *parent = op->getParentOfType<scf::IfOp>();
|
|
while (parent) {
|
|
if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
|
|
parent->getAttr("sg_id_range")))
|
|
return attr;
|
|
parent = parent->getParentOfType<scf::IfOp>();
|
|
}
|
|
return {};
|
|
}
|
|
|
|
/// Returns the per-subgroup shape for `shape` under `layout`, together with
/// the number of distribution units `shape` is split into.
///
/// Without a workgroup-level layout the shape is returned unchanged with a
/// count of 1. Otherwise the subgroup shape is the layout's sg_data when it
/// is non-empty, else shape / sg_layout when that ratio can be computed, else
/// the original shape.
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape,
                   xegpu::DistributeLayoutAttr layout) {
  SmallVector<int64_t> sgShape(shape);
  if (!layout || !layout.isForWorkgroup())
    return std::make_pair(sgShape, 1);

  SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
  SmallVector<int64_t> explicitSgData = layout.getEffectiveSgDataAsInt();
  if (!explicitSgData.empty()) {
    sgShape = explicitSgData;
  } else if (auto derivedSgData = computeShapeRatio(shape, sgLayout)) {
    sgShape = *derivedSgData;
  }

  // A distribution unit spans sg_layout[i] * sgShape[i] elements per
  // dimension. When data is shared among subgroups this product can exceed
  // the workgroup shape, so it is clamped to the original extent.
  SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
  for (size_t dim = 0, rank = distUnit.size(); dim < rank; ++dim)
    distUnit[dim] = std::min(shape[dim], distUnit[dim]);

  int count = computeProduct(shape) / computeProduct(distUnit);
  return std::make_pair(sgShape, count);
}
|
|
|
|
/// Utility helper for deriving a list of offsets for each sub-TensorDescs
|
|
/// or sub-MemDescs to be accessed by current subgroup (sgId) based on the
|
|
/// associated distribute layout attribute, the shape, subgroup id and the
|
|
/// original offsets of the op
|
|
template <
|
|
typename OpType,
|
|
typename = std::enable_if_t<llvm::is_one_of<
|
|
OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
|
|
xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
|
|
static LogicalResult
|
|
genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
|
|
SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
|
|
Location loc = op.getLoc();
|
|
SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
|
|
// not applicable to ops without offsets operands.
|
|
if (origOffsets.empty())
|
|
return failure();
|
|
|
|
// if op is xegpu::CreateNdDescOp, call op.getDescLayoutAttr()
|
|
xegpu::DistributeLayoutAttr layout;
|
|
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
|
|
std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
|
|
layout = op.getLayoutAttr();
|
|
} else {
|
|
layout = op.getDescLayoutAttr();
|
|
}
|
|
|
|
// not applicable to ops without workgroup layout attributes
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
|
|
Value sgId =
|
|
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
|
|
|
|
// verify and adjust the sgId if the range specifier is present
|
|
xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
|
|
if (sgIdRange) {
|
|
int64_t startOfRange = sgIdRange.getStart().getInt();
|
|
int64_t endOfRange = sgIdRange.getEnd().getInt();
|
|
// verify the RangeAttr against the layout attribute
|
|
if (layout.getNumSubgroups() != endOfRange - startOfRange)
|
|
return rewriter.notifyMatchFailure(
|
|
op, "sg_layout size must match the sg_id_range");
|
|
// adjust the sgId if necessary
|
|
if (startOfRange > 0) {
|
|
Value startOfRangeVal =
|
|
arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
|
|
sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
|
|
}
|
|
}
|
|
|
|
// Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
|
|
// descriptors to be accessed, based on the layout information.
|
|
ArrayRef<int64_t> wgShape = op.getDataShape();
|
|
auto maybeDescOffsets =
|
|
layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
|
|
if (failed(maybeDescOffsets))
|
|
return failure();
|
|
|
|
// Compute the final global offsets for each accessed sub-tensor
|
|
// or sub-memory descriptor.
|
|
for (const auto &sgOffsets : *maybeDescOffsets) {
|
|
SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
|
|
rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
|
|
offsetsList.push_back(std::move(newOffsets));
|
|
}
|
|
|
|
// callback(offsetsList);
|
|
return success();
|
|
}
|
|
|
|
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
|
|
/// from a workgroup descriptor. It replaces the offsets and sizes with
|
|
/// appropriate values for the subgroup.
|
|
/// It uses round-robin assignment to distribute the work to the subgroups.
|
|
/// The following create_nd_tdesc operation:
|
|
/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
|
|
/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
|
|
/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
|
|
/// is converted to 9 subgroup level operations based on the sg_layout &
|
|
/// sg_data:
|
|
/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
|
|
/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
|
|
/// lane_data = [1, 1]>>
|
|
///
|
|
/// The sg_layout and sg_data attributes are dropped after the pass as they are
|
|
/// no longer needed.
|
|
///
|
|
/// 24x24 matrix distribution example:
|
|
/// sg_layout = [4, 4], sg_data = [2, 2]
|
|
/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
|
|
/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
|
|
///
|
|
/// +------------------------+
|
|
/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
|
|
/// |-----+-----+-----|
|
|
/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
|
|
/// |-----+-----+-----|
|
|
/// | 8x8 | 8x8 | 8x8 |
|
|
/// +------------------------+
|
|
///
|
|
/// Each 8x8 tile is further subdivided among subgroups:
|
|
/// +------------------------+
|
|
/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
|
|
/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
|
|
/// | 2x2 2x2 2x2 2x2 |
|
|
/// | 2x2 2x2 2x2 2x2 |
|
|
/// +------------------------+
|
|
///
|
|
/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
|
|
/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
|
|
|
|
/// The pass currently has entire distribution logic in the WgToSgCreateNdOp
|
|
/// pattern and all the other ops just follow.
|
|
/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
|
|
/// ops in the pass.
|
|
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  /// Replaces a workgroup-level CreateNdDescOp that has offsets with one
  /// subgroup-level CreateNdDescOp per offset vector produced by
  /// genOffsetsList. Fails when genOffsetsList fails.
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> perSgOffsets;
    if (failed(genOffsetsList(rewriter, op, perSgOffsets)))
      return failure();

    // Subgroup descriptor type: subgroup shape, same element type and
    // encoding, layout with sg_layout / sg_data dropped.
    xegpu::TensorDescType wgDescTy = op.getType();
    xegpu::DistributeLayoutAttr layout = wgDescTy.getLayoutAttr();
    SmallVector<int64_t> sgShape =
        getSgShapeAndCount(wgDescTy.getShape(), layout).first;
    auto sgDescTy = xegpu::TensorDescType::get(
        op.getContext(), sgShape, wgDescTy.getElementType(),
        wgDescTy.getEncoding(), layout.dropSgLayoutAndData());

    SmallVector<Value> sgDescs;
    sgDescs.reserve(perSgOffsets.size());
    for (const SmallVector<OpFoldResult> &offsets : perSgOffsets)
      sgDescs.push_back(xegpu::CreateNdDescOp::create(
          rewriter, op.getLoc(), sgDescTy, op.getSource(), offsets,
          op.getMixedSizes(), op.getMixedStrides()));

    rewriter.replaceOpWithMultiple(op, {sgDescs});
    return success();
  }
};
|
|
|
|
// This pattern transforms the CreateNdDescOp without offsets to create a
|
|
// subgroup descriptor from a workgroup descriptor
|
|
struct WgToSgCreateNdOpNoOffset
|
|
: public OpConversionPattern<xegpu::CreateNdDescOp> {
|
|
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
// Check no offsets are specified.
|
|
if (!op.getMixedOffsets().empty())
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
MLIRContext *ctx = op.getContext();
|
|
xegpu::TensorDescType tdescTy = op.getType();
|
|
auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
|
|
Type elemTy = tdescTy.getElementType();
|
|
ArrayRef<int64_t> wgShape = tdescTy.getShape();
|
|
|
|
SmallVector<int64_t> sgShape;
|
|
int count;
|
|
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
|
|
xegpu::TensorDescType newTdescTy =
|
|
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
|
|
layout.dropSgLayoutAndData());
|
|
|
|
SmallVector<Value> newCreateNdOps(count);
|
|
std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
|
|
return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
|
|
op.getSource(), op.getMixedSizes(),
|
|
op.getMixedStrides());
|
|
});
|
|
|
|
rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// This pattern transforms the LoadNdOp to load subgroup data.
|
|
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
|
|
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!op.getMixedOffsets().empty())
|
|
return failure();
|
|
|
|
SmallVector<Value> newLoadOps;
|
|
for (auto src : adaptor.getTensorDesc()) {
|
|
xegpu::TensorDescType tdescTy =
|
|
dyn_cast<xegpu::TensorDescType>(src.getType());
|
|
ArrayRef<int64_t> srcShape = tdescTy.getShape();
|
|
VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
|
|
auto newLoadOp = xegpu::LoadNdOp::create(
|
|
rewriter, op.getLoc(), newResTy, src,
|
|
xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
|
|
newLoadOps.push_back(newLoadOp);
|
|
}
|
|
rewriter.replaceOpWithMultiple(op, {newLoadOps});
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// This pattern transforms the StoreNdOp to store to a subgroup descriptor
|
|
/// It creates a StoreNdOp op to store the updated values to the new subgroup
|
|
/// src tensor descriptors.
|
|
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
|
|
using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!op.getMixedOffsets().empty())
|
|
return failure();
|
|
|
|
for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
|
|
xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
|
|
op.getL2HintAttr(), op.getL3HintAttr());
|
|
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// This pattern transforms the LoadNdOp with explicit offsets to load
|
|
// subgroup data.
|
|
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
|
|
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
SmallVector<SmallVector<OpFoldResult>> offsetsList;
|
|
if (failed(genOffsetsList(rewriter, op, offsetsList)))
|
|
return failure();
|
|
|
|
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
|
|
if (layout)
|
|
layout = layout.dropSgLayoutAndData();
|
|
SmallVector<Value> newOps;
|
|
for (auto [tdesc, offsets] :
|
|
llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
|
|
auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
|
|
VectorType newResTy =
|
|
VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
|
|
auto newOp = xegpu::LoadNdOp::create(
|
|
rewriter, op.getLoc(), newResTy, tdesc, offsets,
|
|
/*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
|
|
op.getL2HintAttr(), op.getL3HintAttr(), layout);
|
|
newOps.push_back(newOp);
|
|
}
|
|
rewriter.replaceOpWithMultiple(op, {newOps});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// This pattern transforms the StoreNdOp with explicit offsets to store
|
|
// subgroup data.
|
|
struct WgToSgStoreNdOpWithOffset
|
|
: public OpConversionPattern<xegpu::StoreNdOp> {
|
|
using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
SmallVector<SmallVector<OpFoldResult>> offsetsList;
|
|
if (failed(genOffsetsList(rewriter, op, offsetsList)))
|
|
return failure();
|
|
|
|
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
|
|
if (layout)
|
|
layout = layout.dropSgLayoutAndData();
|
|
for (auto [v, tdesc, offsets] :
|
|
llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
|
|
xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
|
|
op.getL1HintAttr(), op.getL2HintAttr(),
|
|
op.getL3HintAttr(), layout);
|
|
}
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
|
|
// subgroup data.
|
|
struct WgToSgPrefetchNdOpWithOffset
|
|
: public OpConversionPattern<xegpu::PrefetchNdOp> {
|
|
using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
SmallVector<SmallVector<OpFoldResult>> offsetsList;
|
|
if (failed(genOffsetsList(rewriter, op, offsetsList)))
|
|
return failure();
|
|
|
|
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
|
|
if (layout)
|
|
layout = layout.dropSgLayoutAndData();
|
|
for (auto [tdesc, offsets] :
|
|
llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
|
|
xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
|
|
op.getL1HintAttr(), op.getL2HintAttr(),
|
|
op.getL3HintAttr(), layout);
|
|
}
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
|
|
/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
|
|
/// offsets of the new subgroup src tensor descriptors.
|
|
struct WgToSgUpdateNdOffsetOp
|
|
: public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
|
|
using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
llvm::SmallVector<Value> newUpdateTileOffsetOps;
|
|
for (auto tDesc : adaptor.getTensorDesc()) {
|
|
auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
|
|
rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
|
|
op.getConstOffsets());
|
|
newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
|
|
}
|
|
|
|
rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// This pattern transforms the DpasOp to work at subgroup level.
|
|
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
|
|
using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
VectorType resultTy = op.getResult().getType();
|
|
if (resultTy.getRank() != 2)
|
|
return failure();
|
|
|
|
auto layoutCd = op.getLayoutCdAttr();
|
|
auto layoutA = op.getLayoutAAttr();
|
|
auto layoutB = op.getLayoutBAttr();
|
|
if (!layoutCd || !layoutA || !layoutB)
|
|
return failure();
|
|
size_t i = 0;
|
|
SmallVector<Value> newDpasOps;
|
|
for (auto aVec : adaptor.getLhs()) {
|
|
for (auto bVec : adaptor.getRhs()) {
|
|
|
|
llvm::SmallVector<Value> operands({aVec, bVec});
|
|
Value tmpC;
|
|
if (op.getAcc()) {
|
|
tmpC = adaptor.getAcc()[i++];
|
|
operands.push_back(tmpC);
|
|
}
|
|
|
|
ArrayRef<int64_t> aVecShape =
|
|
llvm::cast<VectorType>(aVec.getType()).getShape();
|
|
ArrayRef<int64_t> bVecShape =
|
|
llvm::cast<VectorType>(bVec.getType()).getShape();
|
|
VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
|
|
resultTy.getElementType());
|
|
auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
|
|
newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
|
|
newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
|
|
newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
|
|
|
|
newDpasOps.push_back(newDpasOp);
|
|
}
|
|
}
|
|
rewriter.replaceOpWithMultiple(op, {newDpasOps});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
|
|
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
|
|
using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
|
|
if ((offsetSize != 0) || op.getConstOffsetsAttr())
|
|
return failure();
|
|
|
|
for (auto src : adaptor.getTensorDesc())
|
|
xegpu::PrefetchNdOp::create(
|
|
rewriter, op.getLoc(), TypeRange(), src,
|
|
xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// This pattern transforms vector.broadcast ops to work at subgroup level.
|
|
struct WgToSgVectorBroadcastOp
|
|
: public OpConversionPattern<vector::BroadcastOp> {
|
|
using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
VectorType resultType = op.getResult().getType();
|
|
ArrayRef<int64_t> wgShape = resultType.getShape();
|
|
|
|
xegpu::DistributeLayoutAttr layout =
|
|
xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
|
|
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
|
|
VectorType newResultType =
|
|
VectorType::get(sgShape, resultType.getElementType());
|
|
|
|
if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
|
|
return failure();
|
|
|
|
SmallVector<Value> newBroadcastOps;
|
|
for (auto operand : adaptor.getOperands().front()) {
|
|
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
|
|
newResultType, operand);
|
|
|
|
newBroadcastOps.push_back(newBroadcast.getResult());
|
|
}
|
|
rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// This pattern transforms elementwise ops to work at subgroup level.
|
|
struct WgToSgElementwiseOp : public ConversionPattern {
|
|
WgToSgElementwiseOp(MLIRContext *ctx)
|
|
: ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// Only match ops with elementwise trait and single result.
|
|
if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
|
|
return failure();
|
|
|
|
auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
|
|
assert(resultType && "Expected result to be a VectorType");
|
|
|
|
ArrayRef<int64_t> wgShape = resultType.getShape();
|
|
|
|
xegpu::DistributeLayoutAttr layout =
|
|
xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
|
|
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
|
|
|
|
size_t numVariants = operands.empty() ? 0 : operands.front().size();
|
|
|
|
if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
|
|
return operandVec.size() != numVariants;
|
|
}))
|
|
return failure();
|
|
|
|
SmallVector<Value> newResults;
|
|
VectorType newResultType =
|
|
VectorType::get(sgShape, resultType.getElementType());
|
|
|
|
for (size_t i = 0; i < numVariants; ++i) {
|
|
SmallVector<Value> opOperands;
|
|
for (auto &operandVec : operands)
|
|
opOperands.push_back(operandVec[i]);
|
|
|
|
OperationState state(op->getLoc(), op->getName());
|
|
state.addOperands(opOperands);
|
|
state.addTypes(newResultType);
|
|
state.addAttributes(op->getAttrs());
|
|
Operation *newOp = rewriter.create(state);
|
|
xegpu::removeLayoutAttrs(newOp);
|
|
newResults.push_back(newOp->getResult(0));
|
|
}
|
|
|
|
rewriter.replaceOpWithMultiple(op, {newResults});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// clang-format off
|
|
// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
|
|
// If input_layout and target_layout have identical sg_layout and sg_data,
|
|
// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
|
|
// dropped. For example:
|
|
// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
|
|
// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
|
|
// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
|
|
// becomes:
|
|
// #a = #xegpu.layout<inst_data = [16, 16]>
|
|
// #b = #xegpu.layout<inst_data = [8, 16]>
|
|
// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
|
|
// (vector<16x16xf32> is determined by sg_data = [16, 16])
|
|
//
|
|
// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
|
|
// For example:
|
|
// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
|
|
// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
|
|
// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
|
|
// is lowered to:
|
|
// #a = #xegpu.layout<inst_data = [16, 16]>
|
|
// #b = #xegpu.layout<inst_data = [8, 16]>
|
|
// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
|
|
// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
|
|
// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
|
|
// clang-format on
|
|
struct WgToSgConvertLayoutOp
|
|
: public OpConversionPattern<xegpu::ConvertLayoutOp> {
|
|
using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
auto input = op.getInputLayout();
|
|
auto target = op.getTargetLayout();
|
|
|
|
if (!input || !target || !input.isForWorkgroup() ||
|
|
!target.isForWorkgroup())
|
|
return rewriter.notifyMatchFailure(
|
|
op, "Input and target layouts must have subgroup layout");
|
|
|
|
SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
|
|
SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
|
|
DenseI32ArrayAttr inputOrder = input.getOrder();
|
|
SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
|
|
SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
|
|
DenseI32ArrayAttr targetOrder = target.getOrder();
|
|
|
|
// TODO: currently we only support for optimal case, where input and
|
|
// output has the same sg_layout and sg_data, so SLM is not involved.
|
|
if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
|
|
inputOrder != targetOrder)
|
|
return failure();
|
|
|
|
input = input.dropSgLayoutAndData();
|
|
target = target.dropSgLayoutAndData();
|
|
|
|
SmallVector<Value> newOps(adaptor.getSource());
|
|
if (input && target) {
|
|
// keep the ConvertLayoutOp for rest fields, e.g., inst_data.
|
|
for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
|
|
auto newOp = xegpu::ConvertLayoutOp::create(
|
|
rewriter, op.getLoc(), src.getType(), src, input, target);
|
|
newOps[i] = newOp;
|
|
}
|
|
}
|
|
rewriter.replaceOpWithMultiple(op, {newOps});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Handles UnrealizedConversionCastOp generated during
|
|
// SCFStructuralTypeConversions (step 1). This op may appear as either a
|
|
// target or source materialization for Vector values, e.g.:
|
|
// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
|
|
// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
|
|
// it could be either 1:N or N:1 cast. In both cases, the pattern
|
|
// simply forwards the inputs to the outputs using 1:1 or 1:N interface.
|
|
// for example, the following scf::forOp
|
|
// ```
|
|
// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
|
|
// %n = use(%arg1): vector<128x128xf16>
|
|
// scf.yield %n : vector<128x128xf16>
|
|
// }
|
|
// ```
|
|
// Could be converted to:
|
|
// ```
|
|
// %1 = unrealized_conversion_cast %0
|
|
// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
|
|
// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
|
|
// -> (vector<16x16xf16>, vector<16x16xf16>) {
|
|
// %m = unrealized_conversion_cast %arg1, %arg2
|
|
// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
|
|
// %n = use(%m): vector<128x128xf16>
|
|
// %b = unrealized_conversion_cast %n
|
|
// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
|
|
// scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
|
|
// }
|
|
// %cast = unrealized_conversion_cast %for:2
|
|
// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
|
|
// ```
|
|
// TODO: remove it when context-aware type converter is ready.
|
|
struct UnrealizedConversionCastOpPattern
|
|
: public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
|
|
using OpConversionPattern<
|
|
mlir::UnrealizedConversionCastOp>::OpConversionPattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
|
|
|
|
auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
|
|
auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
|
|
|
|
if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
|
|
!llvm::all_equal(ValueRange(inputs).getTypes()))
|
|
return failure();
|
|
|
|
// Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
|
|
// It is generated by source materialization (e.g., inits to scf forOp).
|
|
// The input values provided by the adaptor should already be distributed,
|
|
// and their types should correspond exactly to the result types of the
|
|
// operation.
|
|
if (op.getNumOperands() == 1 &&
|
|
llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
|
|
rewriter.replaceOp(op, inputs);
|
|
return success();
|
|
}
|
|
|
|
// Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
|
|
// It is generated by target materialization (e.g., arguments/results
|
|
// of scf forOp). All input values must have the same vector type, and
|
|
// their shape must be evenly divisible by the output vector's shape
|
|
// (determined by the nature of the workgroup to subgroup distribution).
|
|
// TODO: it is not safe to do such forward, since such N:1 cast could be
|
|
// from others.
|
|
if (op.getNumResults() == 1 &&
|
|
computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
|
|
rewriter.replaceOpWithMultiple(op, {inputs});
|
|
return success();
|
|
}
|
|
|
|
return mlir::failure();
|
|
}
|
|
};
|
|
|
|
// This pattern distributes arith.constant op into subgroup-level constants
|
|
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
|
|
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
|
|
auto vecType = dyn_cast<VectorType>(op.getType());
|
|
if (!vecAttr || !vecType)
|
|
return failure();
|
|
|
|
xegpu::DistributeLayoutAttr layout =
|
|
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
|
|
ArrayRef<int64_t> wgShape = vecType.getShape();
|
|
SmallVector<int64_t> sgShape;
|
|
int count;
|
|
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
|
|
|
|
auto newType = VectorType::get(sgShape, vecType.getElementType());
|
|
Location loc = op.getLoc();
|
|
auto eltType = vecType.getElementType();
|
|
|
|
if (vecAttr.isSplat()) {
|
|
// Splat: single value for all subgroups
|
|
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
|
|
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
|
|
auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
|
|
rewriter.replaceOp(op, cstOp);
|
|
return success();
|
|
} else if (sgShape == wgShape) { // if the entire vector is shared by all
|
|
// subgroups, don't distribute
|
|
auto newConstOp =
|
|
arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
|
|
rewriter.replaceOp(op, newConstOp);
|
|
return success();
|
|
} else {
|
|
// Non-splat constant
|
|
// Only supports 1D & 2D
|
|
// TODO: support other cases that require SLM access
|
|
if (!eltType.isIndex())
|
|
return rewriter.notifyMatchFailure(
|
|
op, "Unsupported element type for non-splat constant op.");
|
|
|
|
if (wgShape.size() > 2)
|
|
return rewriter.notifyMatchFailure(
|
|
op, "Only 1D & 2D vector constant supported");
|
|
|
|
SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
|
|
int64_t rowStride = 0, colStride = 0;
|
|
int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
|
|
int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
|
|
|
|
// Compute colStride and rowStride, and check for constant strides.
|
|
if (cols > 1) {
|
|
colStride = cast<IntegerAttr>(values[1]).getInt() -
|
|
cast<IntegerAttr>(values[0]).getInt();
|
|
}
|
|
if (rows > 1) {
|
|
rowStride = cast<IntegerAttr>(values[cols]).getInt() -
|
|
cast<IntegerAttr>(values[0]).getInt();
|
|
}
|
|
|
|
for (int64_t r = 0; r < rows; ++r) {
|
|
for (int64_t c = 0; c < cols; ++c) {
|
|
int64_t idx = r * cols + c;
|
|
// Check column stride
|
|
if (c > 0 && cols > 1) {
|
|
int64_t prevIdx = r * cols + (c - 1);
|
|
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
|
|
cast<IntegerAttr>(values[prevIdx]).getInt();
|
|
if (diff != colStride)
|
|
return rewriter.notifyMatchFailure(
|
|
op, "Non-constant column stride in constant op.");
|
|
}
|
|
// Check row stride
|
|
if (r > 0 && rows > 1) {
|
|
int64_t prevIdx = (r - 1) * cols + c;
|
|
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
|
|
cast<IntegerAttr>(values[prevIdx]).getInt();
|
|
if (diff != rowStride)
|
|
return rewriter.notifyMatchFailure(
|
|
op, "Non-constant row stride in constant op.");
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create a constant for the base tile.
|
|
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
|
|
// For 1D case, extract the first sgShape[0] elements.
|
|
SmallVector<Attribute> baseTileValues;
|
|
int baseTileCols = sgShape[sgShape.size() - 1];
|
|
int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
|
|
for (int64_t r = 0; r < baseTileRows; ++r) {
|
|
for (int64_t c = 0; c < baseTileCols; ++c) {
|
|
baseTileValues.push_back(values[r * cols + c]);
|
|
}
|
|
}
|
|
|
|
auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
|
|
baseTileValues);
|
|
auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
|
|
|
|
// Get subgroup id
|
|
Value sgId =
|
|
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
|
|
auto sgOffsets =
|
|
layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
|
|
if (failed(sgOffsets))
|
|
return failure();
|
|
|
|
SmallVector<Value, 2> strideConsts;
|
|
strideConsts.push_back(
|
|
arith::ConstantIndexOp::create(rewriter, loc, colStride));
|
|
if (rows > 1)
|
|
strideConsts.insert(
|
|
strideConsts.begin(),
|
|
arith::ConstantIndexOp::create(rewriter, loc, rowStride));
|
|
|
|
SmallVector<Value> newConstOps;
|
|
for (auto offsets : *sgOffsets) {
|
|
// Multiply offset with stride, broadcast it and add to baseConstVec
|
|
Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
|
for (size_t i = 0; i < strideConsts.size(); ++i) {
|
|
Value mul =
|
|
arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
|
|
offsets[i], strideConsts[i]);
|
|
mulOffset = arith::AddIOp::create(
|
|
rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
|
|
}
|
|
// Broadcast to baseConstVec size
|
|
auto bcastOffset = vector::BroadcastOp::create(
|
|
rewriter, loc, baseConstVec.getType(), mulOffset);
|
|
auto finalConst =
|
|
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
|
|
newConstOps.push_back(finalConst);
|
|
}
|
|
rewriter.replaceOpWithMultiple(op, {newConstOps});
|
|
return success();
|
|
}
|
|
}
|
|
};
|
|
|
|
// This pattern transforms the LoadGatherOp with explicit offsets to load
|
|
// subgroup data
|
|
struct WgToSgLoadGatherOpWithOffset
|
|
: public OpConversionPattern<xegpu::LoadGatherOp> {
|
|
using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
if (!op.getOffsets())
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
|
|
if (!resultType)
|
|
return failure();
|
|
ArrayRef<int64_t> wgShape = resultType.getShape();
|
|
|
|
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
|
|
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
|
|
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
|
|
|
|
// The offsets need to be distributed
|
|
auto offsetsVecType =
|
|
dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
|
|
auto maskVecType =
|
|
dyn_cast<VectorType>(adaptor.getMask().front().getType());
|
|
if (!offsetsVecType || !maskVecType ||
|
|
offsetsVecType.getShape() != maskVecType.getShape()) {
|
|
return rewriter.notifyMatchFailure(op,
|
|
"offsets have not been distributed");
|
|
}
|
|
|
|
SmallVector<Value> newLoadOps;
|
|
auto chunkSizeAttr =
|
|
rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
|
|
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
|
|
for (auto [offsets, mask] :
|
|
llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
|
|
auto newLayout = layout.dropSgLayoutAndData();
|
|
auto newLoadOp = xegpu::LoadGatherOp::create(
|
|
rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
|
|
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
|
|
newLayout);
|
|
newLoadOps.push_back(newLoadOp);
|
|
}
|
|
rewriter.replaceOpWithMultiple(op, {newLoadOps});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// This pattern transforms the StoreScatterOp with explicit offsets to store
|
|
// subgroup data
|
|
struct WgToSgStoreScatterOpWithOffset
|
|
: public OpConversionPattern<xegpu::StoreScatterOp> {
|
|
using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
if (!op.getOffsets())
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
|
|
if (!valueType)
|
|
return failure();
|
|
|
|
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
|
|
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
|
|
// The offsets need to be distributed
|
|
auto offsetsVecType =
|
|
dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
|
|
auto maskVecType =
|
|
dyn_cast<VectorType>(adaptor.getMask().front().getType());
|
|
if (!offsetsVecType || !maskVecType ||
|
|
offsetsVecType.getShape() != maskVecType.getShape()) {
|
|
return rewriter.notifyMatchFailure(op,
|
|
"offsets have not been distributed");
|
|
}
|
|
|
|
auto chunkSizeOpt = op.getChunkSize();
|
|
int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
|
|
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
|
|
for (auto [val, offs, mask] : llvm::zip(
|
|
adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
|
|
xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
|
|
mask, chunkSizeAttr, op.getL1HintAttr(),
|
|
op.getL2HintAttr(), op.getL3HintAttr(),
|
|
layout.dropSgLayoutAndData());
|
|
}
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Rewrites a workgroup-level xegpu::LoadMatrixOp into subgroup-level loads:
// one new LoadMatrixOp is created per offset list produced by genOffsetsList.
struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Bail out if the per-subgroup offsets cannot be generated.
    SmallVector<SmallVector<OpFoldResult>> sgOffsetLists;
    if (failed(genOffsetsList(rewriter, op, sgOffsetLists)))
      return failure();

    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
    assert(valueTy && "the value type must be vector type!");

    // The subgroup result keeps the element type and takes the subgroup shape
    // derived from the workgroup data shape and the op's layout.
    xegpu::DistributeLayoutAttr wgLayout = op.getLayoutAttr();
    SmallVector<int64_t> sgShape =
        getSgShapeAndCount(op.getDataShape(), wgLayout).first;
    VectorType sgResultTy = VectorType::get(sgShape, valueTy.getElementType());

    SmallVector<Value> sgLoads;
    for (const auto &sgOffsets : sgOffsetLists) {
      auto sgLoad = xegpu::LoadMatrixOp::create(
          rewriter, op.getLoc(), sgResultTy, op.getMemDesc(), sgOffsets,
          wgLayout.dropSgLayoutAndData());
      sgLoads.push_back(sgLoad);
    }

    rewriter.replaceOpWithMultiple(op, {sgLoads});
    return success();
  }
};
|
|
|
|
// Rewrites a workgroup-level xegpu::StoreMatrixOp into subgroup-level stores:
// each converted data value is paired with one offset list produced by
// genOffsetsList and stored with a new StoreMatrixOp; the original op is then
// erased.
struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Bail out if the per-subgroup offsets cannot be generated.
    SmallVector<SmallVector<OpFoldResult>> sgOffsetLists;
    if (failed(genOffsetsList(rewriter, op, sgOffsetLists)))
      return failure();

    xegpu::DistributeLayoutAttr wgLayout = op.getLayoutAttr();
    for (auto [sgData, sgOffsets] :
         llvm::zip(adaptor.getData(), sgOffsetLists)) {
      xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), sgData,
                                   op.getMemDesc(), sgOffsets,
                                   wgLayout.dropSgLayoutAndData());
    }
    rewriter.eraseOp(op);
    return success();
  }
};
|
|
|
|
// This pattern distributes the vector.step ops to work at subgroup level
|
|
struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
|
|
using OpConversionPattern<vector::StepOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
xegpu::DistributeLayoutAttr layout =
|
|
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
VectorType type = op.getResult().getType();
|
|
auto wgShape = type.getShape();
|
|
std::optional<SmallVector<int64_t>> sgShape =
|
|
getSgShapeAndCount(wgShape, layout).first;
|
|
if (!sgShape)
|
|
return failure();
|
|
|
|
Value sgId =
|
|
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
|
|
auto sgOffsets =
|
|
layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
|
|
if (failed(sgOffsets))
|
|
return failure();
|
|
|
|
VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
|
|
auto steps = vector::StepOp::create(rewriter, loc, newTy);
|
|
SmallVector<Value> newOps;
|
|
for (auto offsets : *sgOffsets) {
|
|
// Broadcast the offset scalar to a vector & add to the base steps
|
|
auto bcastOffset =
|
|
vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
|
|
auto finalSteps =
|
|
arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
|
|
newOps.push_back(finalSteps);
|
|
}
|
|
|
|
rewriter.replaceOpWithMultiple(op, {newOps});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// This pattern transforms vector.shape_cast ops to work at subgroup level.
|
|
struct WgToSgVectorShapeCastOp
|
|
: public OpConversionPattern<vector::ShapeCastOp> {
|
|
using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
|
|
if (!resultType)
|
|
return failure();
|
|
|
|
ArrayRef<int64_t> wgShape = resultType.getShape();
|
|
xegpu::DistributeLayoutAttr layout =
|
|
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
|
|
// Check that srcShape and destShape, if they differ, only differ by
|
|
// expand of unit dimensions.
|
|
auto srcType = dyn_cast<VectorType>(op.getSource().getType());
|
|
if (!srcType)
|
|
return failure();
|
|
|
|
ArrayRef<int64_t> srcShape = srcType.getShape();
|
|
|
|
xegpu::DistributeLayoutAttr layoutToDistribute = layout;
|
|
SmallVector<int64_t> expandedUnitDims;
|
|
if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) {
|
|
xegpu::DistributeLayoutAttr sourceLayout =
|
|
xegpu::getTemporaryLayout(op->getOpOperand(0));
|
|
|
|
auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
|
|
return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
|
|
return isa<vector::BroadcastOp>(user);
|
|
});
|
|
};
|
|
|
|
if (!usedByBroadcastOp(op))
|
|
return rewriter.notifyMatchFailure(
|
|
op, "ShapeCast ops that expand unit dimensions and are used by "
|
|
"non-broadcast operations are not supported.");
|
|
|
|
if (!sourceLayout.isSliceOf(layout))
|
|
return rewriter.notifyMatchFailure(
|
|
op, "The ShapeCast op only expands dimensions, the result layout "
|
|
"must be a slice of the input layout, or vice versa.");
|
|
layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
|
|
layoutToDistribute =
|
|
layoutToDistribute.setUnitDimLayout(expandedUnitDims);
|
|
}
|
|
|
|
SmallVector<int64_t> sgShape =
|
|
getSgShapeAndCount(wgShape, layoutToDistribute).first;
|
|
VectorType newResultType =
|
|
VectorType::get(sgShape, resultType.getElementType());
|
|
|
|
SmallVector<Value> newShapeCastOps;
|
|
for (auto src : adaptor.getSource()) {
|
|
auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
|
|
newResultType, src);
|
|
newShapeCastOps.push_back(newShapeCast.getResult());
|
|
}
|
|
|
|
rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Creates a splat arith.constant of `type` holding the neutral (identity)
/// element of the combining `kind`, i.e. a value `e` with `combine(x, e) == x`:
///   - ADD / XOR / OR / MAXUI : 0
///   - MUL                    : 1
///   - AND                    : all bits set
///   - MINSI / MINUI          : largest signed / unsigned integer
///   - MAXSI                  : smallest signed integer
///   - MINNUMF / MINIMUMF     : +infinity
///   - MAXNUMF / MAXIMUMF     : -infinity
///
/// Returns a null Value when the element type does not fit the kind (e.g. a
/// non-integer element type for an integer min/max, or a non-float element
/// type for a float min/max). Callers must check the result.
static Value createAccumulator(ConversionPatternRewriter &rewriter,
                               Location loc, VectorType type,
                               vector::CombiningKind kind) {
  Type elemTy = type.getElementType();

  // Materializes `attr` as a splat constant of `type`.
  auto makeSplat = [&](Attribute attr) -> Value {
    return arith::ConstantOp::create(rewriter, loc, type,
                                     DenseElementsAttr::get(type, attr));
  };

  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
  case vector::CombiningKind::OR:
  case vector::CombiningKind::MAXUI:
    return makeSplat(rewriter.getZeroAttr(elemTy));

  case vector::CombiningKind::MUL:
    return makeSplat(rewriter.getOneAttr(elemTy));

  case vector::CombiningKind::AND:
    // The identity of bitwise AND is all-ones, not 1: using 1 would clear
    // every bit except bit 0 for element types wider than one bit.
    if (auto intTy = dyn_cast<IntegerType>(elemTy))
      return makeSplat(rewriter.getIntegerAttr(
          elemTy, APInt::getAllOnes(intTy.getWidth())));
    if (elemTy.isIndex())
      return makeSplat(rewriter.getIndexAttr(-1));
    return nullptr;

  case vector::CombiningKind::MINSI:
    // Use max signed int value for signed integer min
    if (auto intTy = dyn_cast<IntegerType>(elemTy))
      return makeSplat(rewriter.getIntegerAttr(
          elemTy, APInt::getSignedMaxValue(intTy.getWidth())));
    return nullptr;

  case vector::CombiningKind::MINUI:
    // Use max unsigned int value for unsigned integer min
    if (auto intTy = dyn_cast<IntegerType>(elemTy))
      return makeSplat(rewriter.getIntegerAttr(
          elemTy, APInt::getMaxValue(intTy.getWidth())));
    return nullptr;

  case vector::CombiningKind::MAXSI:
    // Use min signed int value for signed integer max
    if (auto intTy = dyn_cast<IntegerType>(elemTy))
      return makeSplat(rewriter.getIntegerAttr(
          elemTy, APInt::getSignedMinValue(intTy.getWidth())));
    return nullptr;

  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
    // Use +infinity for float min operations
    if (auto floatTy = dyn_cast<FloatType>(elemTy))
      return makeSplat(rewriter.getFloatAttr(
          elemTy, APFloat::getInf(floatTy.getFloatSemantics())));
    return nullptr;

  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
    // Use -infinity for float max operations
    if (auto floatTy = dyn_cast<FloatType>(elemTy))
      return makeSplat(rewriter.getFloatAttr(
          elemTy,
          APFloat::getInf(floatTy.getFloatSemantics(), /*Negative=*/true)));
    return nullptr;
  }
  return nullptr;
}
|
|
|
|
/// This function converts multi-dimensional subgroup indices into a single
|
|
/// linear offset. It's used to calculate memory offsets in SLM for
|
|
/// cross-subgroup reduction coordination.
|
|
///
|
|
/// Parameters:
|
|
/// - sgIds: Multi-dimensional subgroup indices (e.g., [sgId_x, sgId_y, sgId_z])
|
|
/// - dims: Which dimensions to include in linearization (e.g., [0, 2] for x and
|
|
/// z dims)
|
|
/// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means
|
|
/// 4x8x2 subgroups)
|
|
///
|
|
/// It uses row-major linearization formula:
|
|
/// offset = sum(sgIds[dim] * stride[dim])
|
|
/// where stride[dim] = product of all sgLayout sizes in dimensions after
|
|
/// 'dim'
|
|
///
|
|
/// Example:
|
|
/// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions)
|
|
/// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1)
|
|
/// - Calculation:
|
|
/// * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1
|
|
/// * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4
|
|
/// * linearizedOffset = 1 + 4 = 5
|
|
///
|
|
/// This gives us a unique linear index for each combination of subgroup
|
|
/// positions in the specified dimensions, which is used for SLM row/column
|
|
/// addressing.
|
|
static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
|
|
Location loc, ArrayRef<Value> sgIds,
|
|
ArrayRef<int64_t> dims,
|
|
ArrayRef<int64_t> sgLayout) {
|
|
Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
|
int64_t stride = 1;
|
|
|
|
for (int64_t dim : dims) {
|
|
Value dimVal = sgIds[dim];
|
|
Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
|
|
Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
|
|
linearizedOffset =
|
|
arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
|
|
stride *= sgLayout[dim];
|
|
}
|
|
|
|
return linearizedOffset;
|
|
}
|
|
|
|
/// This pattern transforms vector.multi_dim_reduction operations from
|
|
/// workgroup-level to subgroup-level execution with support for multiple
|
|
/// reduction dimensions.
|
|
///
|
|
/// Steps include:
|
|
/// 1. LOCAL REDUCTION :
|
|
/// - Each subgroup performs local reduction on its data slice
|
|
/// - Uses ZERO accumulator to avoid double-counting during cross-subgroup
|
|
/// phase
|
|
///
|
|
/// 2. CROSS-SUBGROUP :
|
|
/// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
|
|
/// reduction dims & sgData[reduction dims] < wgData[reduction dims])
|
|
/// - If not needed, adds original accumulator and returns local results
|
|
///
|
|
/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
|
|
/// a) SLM Layout Design:
|
|
/// - Rows: subgroups participating in reduction (product of sg_layout in
|
|
/// reduction dims)
|
|
/// - Cols: total result elements across non-reduction dimensions
|
|
///
|
|
/// b) Store Phase:
|
|
/// - Each subgroup stores its local reduction result to SLM
|
|
/// - Row offset: linearized index of subgroup in reduction dimensions
|
|
/// - Col offset: linearized index of subgroup in non-reduction dimensions
|
|
///
|
|
/// c) Load and Final Reduction Phase:
|
|
/// - Each subgroup loads a column of data (all reduction participants for
|
|
/// its position)
|
|
/// - Performs final reduction along the loaded dimension
|
|
/// - Adds original accumulator to get final result
|
|
///
|
|
struct WgToSgMultiDimReductionOp
|
|
: public OpConversionPattern<vector::MultiDimReductionOp> {
|
|
using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
|
|
VectorType srcType = op.getSourceVectorType();
|
|
VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
|
|
if (!dstType)
|
|
return failure();
|
|
|
|
auto originalSrcShape = srcType.getShape();
|
|
xegpu::DistributeLayoutAttr layout =
|
|
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
|
|
auto reductionDims = llvm::to_vector(op.getReductionDims());
|
|
|
|
// Get sg_layout and sg_data from the parent layout
|
|
SmallVector<int64_t> sgLayout;
|
|
SmallVector<int64_t> sgData;
|
|
if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
|
|
sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
|
|
sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
|
|
} else
|
|
return rewriter.notifyMatchFailure(
|
|
op, "Reduction should have SliceAttr layout");
|
|
|
|
Type elemTy = dstType.getElementType();
|
|
|
|
// Step 1: perform local subgroup reductions with ZERO accumulator
|
|
SmallVector<Value> localReductions;
|
|
SmallVector<int64_t> sgShape =
|
|
getSgShapeAndCount(originalSrcShape, layout).first;
|
|
VectorType newDstType = VectorType::get(sgShape, elemTy);
|
|
for (auto sgSrc : adaptor.getSource()) {
|
|
// Create ZERO accumulator for local reduction
|
|
auto neutralLocalAcc =
|
|
createAccumulator(rewriter, loc, newDstType, op.getKind());
|
|
// Local reduction with ZERO accumulator
|
|
auto localReduce = vector::MultiDimReductionOp::create(
|
|
rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc,
|
|
reductionDims);
|
|
localReductions.push_back(localReduce.getResult());
|
|
}
|
|
|
|
// Check if cross-subgroup reduction is needed for any reduction dimension
|
|
SmallVector<int64_t> crossSgReductionDims;
|
|
for (int64_t reductionDim : reductionDims) {
|
|
bool needsCrossSubgroupReduction =
|
|
(sgLayout[reductionDim] > 1) &&
|
|
(sgData[reductionDim] < originalSrcShape[reductionDim]);
|
|
|
|
if (needsCrossSubgroupReduction) {
|
|
crossSgReductionDims.push_back(reductionDim);
|
|
}
|
|
}
|
|
|
|
// If no cross-subgroup reduction needed, add accumulator and return
|
|
if (crossSgReductionDims.empty()) {
|
|
SmallVector<Value> results;
|
|
for (auto localResult : localReductions) {
|
|
auto finalResult = vector::makeArithReduction(
|
|
rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
|
|
results.push_back(finalResult);
|
|
}
|
|
rewriter.replaceOpWithMultiple(op, {results});
|
|
return success();
|
|
}
|
|
|
|
// Step 2: cross-subgroup reduction using SLM
|
|
|
|
// Calculate total elements in local result
|
|
int64_t localElements = computeProduct(sgShape);
|
|
|
|
// Shape cast for SLM storage - store as [1, localElements]
|
|
SmallVector<int64_t> storeShape2D = {1, localElements};
|
|
VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
|
|
auto storeShapeCast = vector::ShapeCastOp::create(
|
|
rewriter, loc, storeType2D, localReductions[0]);
|
|
Value storeData = storeShapeCast.getResult();
|
|
|
|
// Calculate SLM shape - rows for sg's in reduction dims, cols for total
|
|
// result elements across all subgroups in non-reduction dimensions
|
|
int64_t totalReductionSubgroups = 1;
|
|
for (int64_t dim : crossSgReductionDims) {
|
|
totalReductionSubgroups *= sgLayout[dim];
|
|
}
|
|
|
|
// Total result elements across all subgroups in non-reduction dimensions
|
|
int64_t totalResultElements =
|
|
localElements * computeProduct(sgLayout) / totalReductionSubgroups;
|
|
|
|
SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
|
|
totalResultElements};
|
|
|
|
// Allocate SLM
|
|
auto bitWidth = elemTy.getIntOrFloatBitWidth();
|
|
auto bytesPerElement = bitWidth / 8;
|
|
int64_t slmElements = slmShape2D[0] * slmShape2D[1];
|
|
auto slmSize = slmElements * bytesPerElement;
|
|
auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
|
|
auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
|
|
|
|
auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
|
|
slmShape2D, elemTy, nullptr);
|
|
auto memDesc =
|
|
xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
|
|
|
|
// Step 4: Store local results to SLM
|
|
auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
|
|
rewriter.getIndexType(), nullptr);
|
|
|
|
// Convert sgLayout to Values for delinearizeIndex
|
|
SmallVector<Value> sgLayoutValues;
|
|
for (int64_t dim : sgLayout)
|
|
sgLayoutValues.push_back(
|
|
arith::ConstantIndexOp::create(rewriter, loc, dim));
|
|
|
|
auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
|
|
sgLayoutValues);
|
|
if (failed(sgIdsResult))
|
|
return failure();
|
|
SmallVector<Value> sgIds = *sgIdsResult;
|
|
|
|
// Row offset: linearize reduction dimension indices
|
|
Value rowOffsetStore = linearizeSubgroupIndices(
|
|
rewriter, loc, sgIds, crossSgReductionDims, sgLayout);
|
|
|
|
// Column offset: linearize non-reduction dimension indices
|
|
SmallVector<int64_t> nonReductionDims;
|
|
for (size_t i = 0; i < sgLayout.size(); ++i) {
|
|
if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i))) {
|
|
nonReductionDims.push_back(static_cast<int64_t>(i));
|
|
}
|
|
}
|
|
|
|
Value colOffset = linearizeSubgroupIndices(rewriter, loc, sgIds,
|
|
nonReductionDims, sgLayout);
|
|
|
|
Value localElementsVal =
|
|
arith::ConstantIndexOp::create(rewriter, loc, localElements);
|
|
colOffset =
|
|
arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
|
|
|
|
SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
|
|
|
|
xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
|
|
storeOffsets2D, /*layout=*/nullptr);
|
|
|
|
gpu::BarrierOp::create(rewriter, loc);
|
|
|
|
// Step 5: Load from SLM for final reduction
|
|
SmallVector<int64_t> loadShape2D = {totalReductionSubgroups, localElements};
|
|
VectorType loadType2D = VectorType::get(loadShape2D, elemTy);
|
|
|
|
// Load offsets - each subgroup loads its column based on non-reduction
|
|
// position
|
|
Value rowOffsetLoad = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
|
|
|
SmallVector<OpFoldResult> loadOffsets2D = {rowOffsetLoad, colOffset};
|
|
|
|
auto loadOp = xegpu::LoadMatrixOp::create(
|
|
rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
|
|
/*layout=*/nullptr);
|
|
|
|
// Step 6: Perform final reduction with ZERO accumulator
|
|
SmallVector<int64_t> finalReductionDims = {0};
|
|
SmallVector<int64_t> finalResultShape = {localElements};
|
|
VectorType finalResultType = VectorType::get(finalResultShape, elemTy);
|
|
|
|
auto neutralFinalAcc =
|
|
createAccumulator(rewriter, loc, finalResultType, op.getKind());
|
|
|
|
auto finalReduce = vector::MultiDimReductionOp::create(
|
|
rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
|
|
neutralFinalAcc, finalReductionDims);
|
|
|
|
// Step 7: Add the original accumulator at the end
|
|
Value originalAcc = adaptor.getAcc()[0];
|
|
Value accToAdd = originalAcc;
|
|
|
|
// Handle shape mismatch by shape casting
|
|
if (originalAcc.getType() != finalReduce.getResult().getType()) {
|
|
auto originalAccType = cast<VectorType>(originalAcc.getType());
|
|
auto finalResultType =
|
|
cast<VectorType>(finalReduce.getResult().getType());
|
|
|
|
// If they have the same number of elements, just shape cast
|
|
if (originalAccType.getNumElements() ==
|
|
finalResultType.getNumElements()) {
|
|
auto shapeCast = vector::ShapeCastOp::create(
|
|
rewriter, loc, finalResultType, originalAcc);
|
|
accToAdd = shapeCast.getResult();
|
|
}
|
|
}
|
|
|
|
auto finalResult = vector::makeArithReduction(
|
|
rewriter, loc, op.getKind(), finalReduce.getResult(), accToAdd);
|
|
|
|
rewriter.replaceOp(op, finalResult);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// This pattern transforms vector.transpose ops to work at subgroup level.
|
|
struct WgToSgVectorTransposeOp
|
|
: public OpConversionPattern<vector::TransposeOp> {
|
|
using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
VectorType resultType = op.getResultVectorType();
|
|
|
|
ArrayRef<int64_t> wgShape = resultType.getShape();
|
|
xegpu::DistributeLayoutAttr layout =
|
|
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
// TODO-LayoutRefactor: handle the case using getTemporaryLayout
|
|
xegpu::DistributeLayoutAttr sourceLayout =
|
|
xegpu::getDistributeLayoutAttr(op.getVector());
|
|
if (!sourceLayout || !sourceLayout.isForWorkgroup())
|
|
return failure();
|
|
|
|
SmallVector<int64_t> sourceSgLayout =
|
|
sourceLayout.getEffectiveSgLayoutAsInt();
|
|
SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
|
|
DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
|
|
DenseI32ArrayAttr resultOrder = layout.getOrder();
|
|
|
|
if (!sourceOrder || !resultOrder) {
|
|
return rewriter.notifyMatchFailure(
|
|
op, "Both source and result must have order attributes");
|
|
}
|
|
|
|
ArrayRef<int64_t> permutation = op.getPermutation();
|
|
size_t permutationSize = permutation.size();
|
|
if (sourceSgLayout.size() != permutationSize ||
|
|
resultSgLayout.size() != permutationSize) {
|
|
return rewriter.notifyMatchFailure(
|
|
op, "Layouts and permutation must have the same rank");
|
|
}
|
|
|
|
// Check that sgLayout, sgData & order are properly transposed for source
|
|
// and result
|
|
if (!layout.isTransposeOf(sourceLayout, permutation))
|
|
return rewriter.notifyMatchFailure(
|
|
op, "Result layout is not a valid transpose of source layout "
|
|
"according to permutation");
|
|
|
|
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
|
|
VectorType newResultType =
|
|
VectorType::get(sgShape, resultType.getElementType());
|
|
SmallVector<Value> newTransposeOps;
|
|
for (auto src : adaptor.getVector()) {
|
|
auto newTranspose = vector::TransposeOp::create(
|
|
rewriter, op.getLoc(), newResultType, src, permutation);
|
|
newTransposeOps.push_back(newTranspose.getResult());
|
|
}
|
|
|
|
rewriter.replaceOpWithMultiple(op, {newTransposeOps});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Distribute vector mask ops to work at subgroup level.
|
|
template <typename MaskOpType>
|
|
struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
|
|
using OpConversionPattern<MaskOpType>::OpConversionPattern;
|
|
|
|
LogicalResult matchAndRewrite(
|
|
MaskOpType op,
|
|
typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
xegpu::DistributeLayoutAttr layout =
|
|
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
|
|
if (!layout || !layout.isForWorkgroup())
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
VectorType type = op.getResult().getType();
|
|
auto wgShape = type.getShape();
|
|
|
|
SmallVector<Value> wgMaskDimSizes;
|
|
if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
|
|
for (int64_t maskSize : op.getMaskDimSizes()) {
|
|
wgMaskDimSizes.push_back(
|
|
arith::ConstantIndexOp::create(rewriter, loc, maskSize));
|
|
}
|
|
} else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
|
|
wgMaskDimSizes = llvm::to_vector(op.getOperands());
|
|
}
|
|
|
|
Value sgId =
|
|
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
|
|
auto sgOffsets =
|
|
layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
|
|
if (failed(sgOffsets))
|
|
return failure();
|
|
|
|
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
|
|
VectorType resultType = VectorType::get(sgShape, type.getElementType());
|
|
|
|
// In each dimension, each subgroup computes its local mask size as:
|
|
// min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
|
|
SmallVector<Value> newCreateMaskOps;
|
|
for (auto offsetSet : *sgOffsets) {
|
|
SmallVector<Value> maskOperands;
|
|
|
|
for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
|
|
Value dimSizeVal =
|
|
arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
|
|
Value offset = offsetSet[i];
|
|
Value adjustedMaskSize =
|
|
arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
|
|
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
|
Value nonNegative =
|
|
arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
|
|
Value sgMaskSize =
|
|
arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
|
|
maskOperands.push_back(sgMaskSize);
|
|
}
|
|
|
|
auto newCreateMaskOp =
|
|
vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
|
|
newCreateMaskOps.push_back(newCreateMaskOp.getResult());
|
|
}
|
|
|
|
rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Instantiation of the WgToSgVectorMaskOp pattern for vector::ConstantMaskOp.
/// For this op the pattern materializes each static mask dimension size as an
/// arith::ConstantIndexOp before computing the per-subgroup mask sizes.
using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
|
|
/// Instantiation of the WgToSgVectorMaskOp pattern for vector::CreateMaskOp.
/// For this op the pattern uses the op's operands directly as the
/// workgroup-level mask dimension sizes.
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace xegpu {
|
|
/// Adds the workgroup-to-subgroup distribution patterns to `patterns`.
///
/// Every pattern is constructed with the MLIRContext returned by
/// `patterns.getContext()`. The patterns are added in several batches purely
/// for readability; the insertion order is the sequence in which the pattern
/// names appear below.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();

  patterns.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
               WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp,
               WgToSgStoreNdOpWithOffset, WgToSgUpdateNdOffsetOp>(ctx);

  patterns.add<WgToSgDpasOp, WgToSgPrefetchNdOp, WgToSgPrefetchNdOpWithOffset,
               UnrealizedConversionCastOpPattern>(ctx);

  patterns.add<WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
               WgToSgConvertLayoutOp, WgToSgArithConstantOp>(ctx);

  patterns.add<WgToSgLoadGatherOpWithOffset, WgToSgStoreScatterOpWithOffset,
               WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(ctx);

  patterns.add<WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
               WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
               WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(ctx);
}
|
|
} // namespace xegpu
|
|
} // namespace mlir
|
|
|
|
namespace {
|
|
struct XeGPUWgToSgDistributePass
|
|
: public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
void XeGPUWgToSgDistributePass::runOnOperation() {
|
|
|
|
Operation *op = getOperation();
|
|
if (!xegpu::recoverTemporaryLayouts(op)) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
// Track existing UnrealizedConversionCastOps
|
|
SmallVector<Operation *> existingCastOps;
|
|
getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
|
|
existingCastOps.push_back(castOp.getOperation());
|
|
});
|
|
|
|
{
|
|
// Step 1: Apply SCFStructuralTypeConversions to SCF operations with
|
|
// VectorType operands. This first converts such operands to
|
|
// RankedTensorType, propagates the layout attribute into the encoding
|
|
// attribute, and finally converts the RankedTensorType to VectorType based
|
|
// on the encoding.
|
|
|
|
TypeConverter converter;
|
|
converter.addConversion([&](Type type) -> Type { return type; });
|
|
converter.addConversion(
|
|
[&](RankedTensorType type,
|
|
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
|
|
Type elemTy = type.getElementType();
|
|
ArrayRef<int64_t> shape = type.getShape();
|
|
|
|
int count;
|
|
SmallVector<int64_t> subShape;
|
|
std::tie(subShape, count) = getSgShapeAndCount(
|
|
shape,
|
|
dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
|
|
|
|
auto newTy = VectorType::get(subShape, elemTy);
|
|
result.append(count, newTy);
|
|
return success();
|
|
});
|
|
|
|
xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
|
|
converter);
|
|
}
|
|
|
|
// Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
|
|
// as well as XeGPU, Arith, and Vector operations.
|
|
MLIRContext *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
ConversionTarget target(*ctx);
|
|
TypeConverter converter;
|
|
converter.addConversion([&](Type type) -> Type { return type; });
|
|
converter.addConversion(
|
|
[&](xegpu::TensorDescType type,
|
|
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
|
|
Type elemTy = type.getElementType();
|
|
ArrayRef<int64_t> shape = type.getShape();
|
|
|
|
int count;
|
|
SmallVector<int64_t> subShape;
|
|
xegpu::LayoutAttr layout = type.getLayoutAttr();
|
|
std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
|
|
|
|
if (layout)
|
|
layout = layout.dropSgLayoutAndData();
|
|
|
|
auto newTy = xegpu::TensorDescType::get(
|
|
type.getContext(), subShape, elemTy, type.getEncoding(), layout);
|
|
result.append(count, newTy);
|
|
return success();
|
|
});
|
|
|
|
auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
|
|
if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
|
|
return createOp.getType();
|
|
if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
|
|
return loadOp.getTensorDescType();
|
|
if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
|
|
return storeOp.getTensorDescType();
|
|
if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
|
|
return updateOp.getType();
|
|
if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
|
|
return prefetchOp.getTensorDescType();
|
|
return xegpu::TensorDescType();
|
|
};
|
|
|
|
auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
|
|
return !layout || !layout.isForWorkgroup();
|
|
};
|
|
|
|
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
|
|
xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
|
|
xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
|
|
auto tdescTy = getTensorDescType(op);
|
|
auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
|
|
return isLegal(layout);
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
|
|
auto layout = op.getLayoutCdAttr();
|
|
return isLegal(layout);
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
|
|
[=](xegpu::LoadMatrixOp op) -> bool {
|
|
return isLegal(op.getLayoutAttr());
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
|
|
[=](xegpu::StoreMatrixOp op) -> bool {
|
|
return isLegal(op.getLayoutAttr());
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<arith::ConstantOp>(
|
|
[=](arith::ConstantOp op) -> bool {
|
|
auto vecType = dyn_cast<VectorType>(op.getType());
|
|
if (!vecType)
|
|
return true;
|
|
|
|
auto layout =
|
|
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
|
|
return isLegal(layout);
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
|
|
vector::TransposeOp, vector::BroadcastOp,
|
|
vector::MultiDimReductionOp,
|
|
vector::ConstantMaskOp, vector::CreateMaskOp>(
|
|
[=](Operation *op) -> bool {
|
|
// Check for either a SliceAttr or LayoutAttr on the result.
|
|
auto layout =
|
|
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
|
|
return isLegal(layout);
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
|
|
[=](xegpu::LoadGatherOp op) -> bool {
|
|
auto layout = op.getLayoutAttr();
|
|
return isLegal(layout);
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
|
|
[=](xegpu::StoreScatterOp op) -> bool {
|
|
auto layout = op.getLayoutAttr();
|
|
return isLegal(layout);
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
|
|
[=](xegpu::ConvertLayoutOp op) -> bool {
|
|
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
|
|
});
|
|
|
|
target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
|
|
[=](Operation *op) -> std::optional<bool> {
|
|
// Only handle elementwise mappable ops
|
|
if (!OpTrait::hasElementwiseMappableTraits(op))
|
|
return true;
|
|
|
|
VectorType resultType =
|
|
dyn_cast<VectorType>(op->getResult(0).getType());
|
|
if (!resultType)
|
|
return true;
|
|
|
|
// Check if all operands are vectors of the same shape
|
|
// TODO: Support other types.
|
|
for (Value operand : op->getOperands()) {
|
|
VectorType operandType = dyn_cast<VectorType>(operand.getType());
|
|
if (!operandType || operandType.getShape() != resultType.getShape()) {
|
|
return true;
|
|
}
|
|
}
|
|
|
|
xegpu::DistributeLayoutAttr layout =
|
|
xegpu::getTemporaryLayout(op->getResult(0));
|
|
return isLegal(layout);
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
|
|
[=](UnrealizedConversionCastOp op) {
|
|
return llvm::is_contained(existingCastOps, op.getOperation());
|
|
});
|
|
|
|
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
|
|
|
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
|
|
target);
|
|
xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
|
|
if (failed(
|
|
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|