This PR refactors layout propagation into two distinct components: result/anchor layout setup, and source layout inference from the result. For operations that require a specific result layout due to semantic or hardware constraints, the propagation logic explicitly sets up the result or anchor layout. Otherwise, it infers the source layout from the backward-propagated consumer layout. The result or anchor layout may differ from the backward-propagated consumer layout; any such discrepancies are resolved via the existing layout-conflict mechanism.

**This PR introduces the following utility functions:**

Source layout inference:
> inferBroadcastSourceLayout()
> inferMultiReductionSourceLayout()
> inferBitCastSourceLayout()
> inferShapeCastSourceLayout()
> inferInsertStridedSliceSourceLayout()

Result / anchor layout setup:
> setupMultiReductionResultLayout()
> setupBitCastResultLayout()
> setupInsertStridedSliceResultLayout()
> setupLoadMatrixAnchorLayout()
> setupStoreMatrixAnchorLayout()
> setupLoadGatherAnchorLayout()
> setupStoreScatterAnchorLayout()

Part of the subgroup-distribution-related code changes has been split out into a separate PR: https://github.com/llvm/llvm-project/pull/179018/changes.
1047 lines
40 KiB
C++
1047 lines
40 KiB
C++
//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file contains patterns for unrolling XeGPU operations. It follows a
|
|
// similar concept and design as vector unroll patterns, serving as a complement
|
|
// to them.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
|
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
|
|
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
|
|
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Support/DebugLog.h"
|
|
|
|
namespace mlir {
|
|
namespace xegpu {
|
|
#define GEN_PASS_DEF_XEGPUUNROLL
|
|
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
|
|
} // namespace xegpu
|
|
} // namespace mlir
|
|
|
|
#define DEBUG_TYPE "xegpu-unroll"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
template <typename SourceOp>
|
|
struct UnrollPattern : public OpRewritePattern<SourceOp> {
|
|
UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
|
|
PatternBenefit benefit = 1)
|
|
: OpRewritePattern<SourceOp>(context, benefit), options(options) {}
|
|
|
|
protected:
|
|
/// Return the target shape for the given `op`. Return std::nullopt if the
|
|
/// op shouldn't be or cannot be unrolled.
|
|
std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
|
|
LDBG() << "Get unroll shape for: " << *op;
|
|
|
|
if (options.filterConstraint && failed(options.filterConstraint(op))) {
|
|
LDBG() << "--no filter constraint -> BAIL";
|
|
return std::nullopt;
|
|
}
|
|
|
|
assert(options.nativeShape &&
|
|
"expects the native shape for native shape call back function.");
|
|
auto nativeShape = options.nativeShape(op);
|
|
return nativeShape;
|
|
}
|
|
|
|
SmallVector<Type> getUnrolledTypes(ShapedType type,
|
|
ArrayRef<int64_t> tileShape,
|
|
bool returnSingleType = false) const {
|
|
return options.getUnrolledTypes(type, tileShape, returnSingleType);
|
|
}
|
|
|
|
/// Emulate the the unpack behavior using insert_strided_slice for VectorType
|
|
/// values and unrealized_conversion_cast for TensorDescType values.
|
|
Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
|
|
Location loc, PatternRewriter &rewriter) const {
|
|
if (auto vecTy = dyn_cast<VectorType>(destTy)) {
|
|
auto shape = vecTy.getShape();
|
|
return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
|
|
}
|
|
|
|
if (isa<xegpu::TensorDescType>(destTy)) {
|
|
auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
|
|
rewriter.getUnitAttr());
|
|
auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
|
|
rewriter.getDenseI64ArrayAttr(blockSize));
|
|
auto castOp = UnrealizedConversionCastOp::create(
|
|
rewriter, loc, destTy, srcs,
|
|
ArrayRef<NamedAttribute>({attr, blkAttr}));
|
|
return castOp.getResult(0);
|
|
}
|
|
|
|
llvm_unreachable("Unexpected destTy.");
|
|
return Value();
|
|
}
|
|
|
|
/// Emulate the the pack behavior using extract_strided_slice for VectorType
|
|
/// values and unrealized_conversion_cast for TensorDescType values.
|
|
SmallVector<Value> pack(Value src, TypeRange destTypes,
|
|
ArrayRef<int64_t> blockSize, Location loc,
|
|
PatternRewriter &rewriter) const {
|
|
if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
|
|
return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
|
|
blockSize);
|
|
}
|
|
|
|
if (isa<xegpu::TensorDescType>(src.getType())) {
|
|
auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
|
|
rewriter.getUnitAttr());
|
|
auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
|
|
rewriter.getDenseI64ArrayAttr(blockSize));
|
|
auto castOp = UnrealizedConversionCastOp::create(
|
|
rewriter, loc, destTypes, src,
|
|
ArrayRef<NamedAttribute>({attr, blkAttr}));
|
|
return castOp.getResults();
|
|
}
|
|
|
|
llvm_unreachable("Unexpected src type.");
|
|
return SmallVector<Value>();
|
|
}
|
|
|
|
private:
|
|
const char *const packAttrName = "__xegpu_blocking_pack__";
|
|
const char *const unpackAttrName = "__xegpu_blocking_unpack__";
|
|
const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
|
|
|
|
xegpu::UnrollOptions options;
|
|
};
|
|
|
|
// Generic helper function for unrolling operations with offsets.
|
|
//
|
|
// Iterates over tile offsets within the tensor descriptor shape and calls
|
|
// the provided createOp function for each computed offset. This is used by
|
|
// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
|
|
// have explicit offsets that need to be adjusted for each unrolled tile.
|
|
SmallVector<Value> computeUnrolledOffsets(
|
|
SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
|
|
ArrayRef<int64_t> targetShape,
|
|
const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
|
|
Location loc, PatternRewriter &rewriter) {
|
|
int64_t rank = tdescTy.getRank();
|
|
ArrayRef<int64_t> shape = tdescTy.getShape();
|
|
|
|
auto addi = [&](OpFoldResult a, int64_t b) -> Value {
|
|
std::optional<int64_t> maybeInt = getConstantIntValue(a);
|
|
if (maybeInt) {
|
|
return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
|
|
} else {
|
|
auto aV = llvm::cast<Value>(a);
|
|
auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
|
|
return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
|
|
}
|
|
};
|
|
|
|
SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
|
|
llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
|
|
auto validIdxes =
|
|
llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
|
|
|
|
SmallVector<Value> newOps;
|
|
for (SmallVector<int64_t> offsets :
|
|
StaticTileOffsetRange(shape, targetShape)) {
|
|
|
|
for (auto [idx, oldOff, offset] :
|
|
llvm::zip(validIdxes, oldOffsets, offsets))
|
|
mixedOffsets[idx] = addi(oldOff, offset);
|
|
|
|
auto newOp = createOp(mixedOffsets);
|
|
newOps.push_back(newOp);
|
|
}
|
|
return newOps;
|
|
}
|
|
|
|
/// Unrolls a CreateNdDescOp into descriptors of the target tile type. Without
/// offsets a single tile descriptor is created; with offsets one descriptor is
/// created per tile through `computeUnrolledOffsets`. The results are merged
/// back into the original descriptor type with `unpack`.
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> tileShape = getTargetShape(op);
    if (!tileShape)
      return failure();

    Location loc = op.getLoc();
    xegpu::TensorDescType origDescTy = op.getType();
    Type tileDescTy = getUnrolledTypes(origDescTy, *tileShape)[0];

    SmallVector<Value> tileDescs;
    if (op.getMixedOffsets().size() == 0) {
      // No offsets on the op: a single descriptor of the tile type suffices.
      Value tileDesc = xegpu::CreateNdDescOp::create(
          rewriter, loc, tileDescTy, op.getSource(), op.getMixedSizes(),
          op.getMixedStrides());
      tileDescs.push_back(tileDesc);
    } else {
      // One descriptor per tile, each with its own adjusted offsets.
      auto buildTileDesc = [&](SmallVector<OpFoldResult> tileOffsets) -> Value {
        return xegpu::CreateNdDescOp::create(
            rewriter, loc, tileDescTy, op.getSource(), tileOffsets,
            op.getMixedSizes(), op.getMixedStrides());
      };
      tileDescs = computeUnrolledOffsets(op.getMixedOffsets(), origDescTy,
                                         *tileShape, buildTileDesc, loc,
                                         rewriter);
    }

    Value merged = unpack(tileDescs, origDescTy, *tileShape, loc, rewriter);
    rewriter.replaceOp(op, merged);
    return success();
  }
};
|
|
|
|
/// Unrolls an UpdateNdOffsetOp: the tensor descriptor operand is packed into
/// tile descriptors, the same offsets are applied to every tile, and the
/// updated tiles are unpacked into the original result type.
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> tileShape = getTargetShape(op);
    if (!tileShape)
      return failure();

    Location loc = op.getLoc();
    xegpu::TensorDescType descTy = op.getTensorDescType();

    // Split the incoming descriptor into per-tile descriptors.
    SmallVector<Type> tileDescTypes = getUnrolledTypes(descTy, *tileShape);
    SmallVector<Value> tileDescs =
        pack(op.getTensorDesc(), tileDescTypes, *tileShape, loc, rewriter);

    // Every tile receives the same dynamic and constant offsets.
    SmallVector<Value> updatedTiles;
    for (Value tileDesc : tileDescs) {
      Value updated = xegpu::UpdateNdOffsetOp::create(
          rewriter, loc, tileDesc.getType(), tileDesc, op.getOffsets(),
          op.getConstOffsets());
      updatedTiles.push_back(updated);
    }

    Value merged =
        unpack(updatedTiles, op.getType(), *tileShape, loc, rewriter);
    rewriter.replaceOp(op, merged);
    return success();
  }
};
|
|
|
|
/// Unrolls a PrefetchNdOp. Without offsets, one prefetch is emitted per packed
/// tile descriptor. With offsets, a single tile descriptor is reused and one
/// prefetch is emitted per tile offset. The original op is erased afterwards.
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> tileShape = getTargetShape(op);
    if (!tileShape)
      return failure();

    Location loc = op.getLoc();
    xegpu::TensorDescType descTy = op.getTensorDescType();

    // The unrolled ops carry the layout without its inst_data part.
    xegpu::DistributeLayoutAttr tileLayout = op.getLayoutAttr();
    if (tileLayout)
      tileLayout = tileLayout.dropInstData();

    // Offsets may be given dynamically and/or through the constant attribute.
    int64_t numDynOffsets = static_cast<int64_t>(op.getOffsets().size());
    bool usesOffsets = (numDynOffsets != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> tileDescTypes = getUnrolledTypes(
        descTy, *tileShape, /*returnSingleType*/ usesOffsets);
    SmallVector<Value> tileDescs =
        pack(op.getTensorDesc(), tileDescTypes, *tileShape, loc, rewriter);

    if (usesOffsets) {
      auto emitPrefetch = [&](SmallVector<OpFoldResult> tileOffsets) -> Value {
        xegpu::PrefetchNdOp::create(rewriter, loc, tileDescs[0], tileOffsets,
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr(), tileLayout);
        // Prefetch has no result; a null Value satisfies the callback type.
        return Value();
      };
      computeUnrolledOffsets(op.getMixedOffsets(), descTy, *tileShape,
                             emitPrefetch, loc, rewriter);
    } else {
      for (Value tileDesc : tileDescs)
        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), tileDesc,
                                    xegpu::dropInstDataOnAttrs(op->getAttrs()));
    }

    rewriter.eraseOp(op);
    return success();
  }
};
|
|
|
|
/// Unrolls a LoadNdOp into loads of the target tile shape. Without offsets,
/// each packed tile descriptor is loaded; with offsets, a single tile
/// descriptor is loaded at every tile offset. The tile results are unpacked
/// into a value of the original result type.
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> tileShape = getTargetShape(op);
    if (!tileShape)
      return failure();

    Location loc = op.getLoc();
    VectorType resultTy = op.getType();
    xegpu::TensorDescType descTy = op.getTensorDescType();

    // The unrolled ops carry the layout without its inst_data part.
    xegpu::DistributeLayoutAttr tileLayout = op.getLayoutAttr();
    if (tileLayout)
      tileLayout = tileLayout.dropInstData();

    // Offsets may be given dynamically and/or through the constant attribute.
    int64_t numDynOffsets = static_cast<int64_t>(op.getOffsets().size());
    bool usesOffsets = (numDynOffsets != 0) || op.getConstOffsetsAttr();

    // Result type of each unrolled load: tile shape, descriptor element type.
    VectorType tileResultTy =
        resultTy.cloneWith(*tileShape, descTy.getElementType());

    SmallVector<Type> tileDescTypes = getUnrolledTypes(
        descTy, *tileShape, /*returnSingleType*/ usesOffsets);
    SmallVector<Value> tileDescs =
        pack(op.getTensorDesc(), tileDescTypes, *tileShape, loc, rewriter);

    SmallVector<Value> tileLoads;
    if (usesOffsets) {
      auto emitLoad = [&](SmallVector<OpFoldResult> tileOffsets) -> Value {
        return xegpu::LoadNdOp::create(
            rewriter, loc, tileResultTy, tileDescs[0], tileOffsets,
            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
            op.getL2HintAttr(), op.getL3HintAttr(), tileLayout);
      };
      tileLoads = computeUnrolledOffsets(op.getMixedOffsets(), descTy,
                                         *tileShape, emitLoad, loc, rewriter);
    } else {
      for (Value tileDesc : tileDescs) {
        Value loaded =
            xegpu::LoadNdOp::create(rewriter, loc, tileResultTy, tileDesc,
                                    xegpu::dropInstDataOnAttrs(op->getAttrs()));
        tileLoads.push_back(loaded);
      }
    }

    Value merged = unpack(tileLoads, op.getType(), *tileShape, loc, rewriter);
    rewriter.replaceOp(op, merged);
    return success();
  }
};
|
|
|
|
/// Unrolls a StoreNdOp into stores of the target tile shape. The stored value
/// is always packed into tiles. Without offsets each value tile is paired with
/// a packed tile descriptor; with offsets a single tile descriptor is reused
/// and the value tiles are consumed in tile-offset iteration order.
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> tileShape = getTargetShape(op);
    if (!tileShape)
      return failure();

    Location loc = op.getLoc();
    VectorType storedTy = op.getValueType();
    xegpu::TensorDescType descTy = op.getTensorDescType();

    // The unrolled ops carry the layout without its inst_data part.
    xegpu::DistributeLayoutAttr tileLayout = op.getLayoutAttr();
    if (tileLayout)
      tileLayout = tileLayout.dropInstData();

    // Offsets may be given dynamically and/or through the constant attribute.
    int64_t numDynOffsets = static_cast<int64_t>(op.getOffsets().size());
    bool usesOffsets = (numDynOffsets != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> tileValueTypes = getUnrolledTypes(storedTy, *tileShape);
    SmallVector<Type> tileDescTypes = getUnrolledTypes(
        descTy, *tileShape, /*returnSingleType*/ usesOffsets);

    // Descriptor is packed before the value, matching the op creation order.
    SmallVector<Value> tileDescs =
        pack(op.getTensorDesc(), tileDescTypes, *tileShape, loc, rewriter);
    SmallVector<Value> tileValues =
        pack(op.getValue(), tileValueTypes, *tileShape, loc, rewriter);

    if (usesOffsets) {
      // Value tiles are consumed one per callback invocation.
      size_t nextTile = 0;
      auto emitStore = [&](SmallVector<OpFoldResult> tileOffsets) -> Value {
        xegpu::StoreNdOp::create(rewriter, loc, tileValues[nextTile++],
                                 tileDescs[0], tileOffsets,
                                 op.getL1HintAttr(), op.getL2HintAttr(),
                                 op.getL3HintAttr(), tileLayout);
        // Store has no result; a null Value satisfies the callback type.
        return Value();
      };
      computeUnrolledOffsets(op.getMixedOffsets(), descTy, *tileShape,
                             emitStore, loc, rewriter);
    } else {
      // NOTE(review): unlike the offset path, this builder call does not
      // forward the layout attribute - confirm this is intended.
      for (auto [valueTile, descTile] : llvm::zip(tileValues, tileDescs))
        xegpu::StoreNdOp::create(rewriter, loc, valueTile, descTile,
                                 op.getL1HintAttr(), op.getL2HintAttr(),
                                 op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};
|
|
|
|
/// Unrolls a DpasOp into DpasOps on [M, K] x [K, N] (+ [M, N]) blocks, where
/// M, K and N come from the 3-element target shape. For each (m, n) block the
/// partial products along K are chained through the accumulator operand, and
/// the per-block results are unpacked into the original result type.
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // expecting every operands is a 2D Vector
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    // Reject operands whose shape is not an integral multiple of the block
    // size before any IR is created. Previously this was only guarded by an
    // assert, so a release build would dereference an empty optional.
    auto isTileable = [](VectorType type, ArrayRef<int64_t> blockSize) {
      return computeShapeRatio(type.getShape(), blockSize).has_value();
    };
    if (!isTileable(a.getType(), aBlockSize) ||
        !isTileable(b.getType(), bBlockSize) ||
        (c && !isTileable(c.getType(), cBlockSize)))
      return failure();

    // Split `val` into blocks of `blockSize`. A value that already has the
    // block shape is returned as-is without creating any op.
    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if every operand has an invalid blocking size (empty)
    // or if the original shape matches the blocking size (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with acc

        // Accumulate along K: each dpas result feeds the next one as acc.
        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC =
              xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
                                    xegpu::dropInstDataOnAttrs(op->getAttrs()));
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
|
|
|
|
/// Unrolls a CreateDescOp for scattered tensor descriptors. The offsets vector
/// is packed into tiles and one descriptor is created per tile. When the chunk
/// size is larger than 1, each offsets tile additionally yields one descriptor
/// per blocked chunk, with the chunk start added to the offsets.
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType descTy = op.getType();
    if (!descTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> tileShape = getTargetShape(op);
    if (!tileShape)
      return failure();

    Location loc = op.getLoc();
    TypedValue<::mlir::VectorType> offsetsVec = op.getOffsets();
    VectorType offsetsVecTy = offsetsVec.getType();

    const int64_t chunkSize = descTy.getChunkSizeAsInt();
    const bool isChunked = chunkSize > 1;

    // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
    SmallVector<int64_t> offsetsTileShape(*tileShape);
    if (isChunked)
      offsetsTileShape.pop_back();

    Type tileDescTy = getUnrolledTypes(descTy, *tileShape)[0];
    SmallVector<Type> offsetsTileTypes =
        getUnrolledTypes(offsetsVecTy, offsetsTileShape);
    SmallVector<Value> offsetsTiles =
        pack(offsetsVec, offsetsTileTypes, offsetsTileShape, loc, rewriter);

    SmallVector<Value> tileDescs;
    if (!isChunked) {
      for (Value tileOffsets : offsetsTiles) {
        Value tileDesc = xegpu::CreateDescOp::create(
            rewriter, loc, tileDescTy, op.getSource(), tileOffsets);
        tileDescs.push_back(tileDesc);
      }
    } else {
      // More indices is need when chunkSize > 1. Since a big load from one
      // address could be break into multiple small loads.
      const int64_t blockedChunkSize = tileShape->back();
      const int64_t numNewChunks = chunkSize / blockedChunkSize;

      for (auto [tileOffsets, tileOffsetsTy] :
           llvm::zip(offsetsTiles, offsetsTileTypes)) {
        for (int64_t chunk = 0; chunk < numNewChunks; ++chunk) {
          // Shift the offsets by the start of this chunk.
          Value chunkStart = arith::ConstantIndexOp::create(
              rewriter, loc, chunk * blockedChunkSize);
          Value chunkStartVec = vector::BroadcastOp::create(
              rewriter, loc, tileOffsetsTy, chunkStart);
          Value shiftedOffsets =
              arith::AddIOp::create(rewriter, loc, tileOffsets, chunkStartVec);

          Value tileDesc = xegpu::CreateDescOp::create(
              rewriter, loc, tileDescTy, op.getSource(), shiftedOffsets);
          tileDescs.push_back(tileDesc);
        }
      }
    }

    Value merged = unpack(tileDescs, descTy, *tileShape, loc, rewriter);
    rewriter.replaceOp(op, merged);
    return success();
  }
};
|
|
|
|
/// Unrolls a LoadGatherOp that reads through a tensor descriptor (no explicit
/// offsets operand). Descriptor and mask are packed into tiles; for chunked
/// descriptors each mask tile is repeated once per blocked chunk. One gather
/// load is emitted per (descriptor, mask) pair and the results are unpacked.
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType descTy = op.getTensorDescType();

    // TODO: handle the unstructure source case (!tdesTy)
    if (!descTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> tileShape = getTargetShape(op);
    if (!tileShape)
      return failure();

    Location loc = op.getLoc();
    VectorType resultTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    // Result type of each unrolled load: tile shape, descriptor element type.
    VectorType tileResultTy =
        resultTy.cloneWith(*tileShape, descTy.getElementType());

    SmallVector<Type> tileDescTypes = getUnrolledTypes(descTy, *tileShape);
    SmallVector<Value> tileDescs =
        pack(op.getTensorDesc(), tileDescTypes, *tileShape, loc, rewriter);

    // For chunked descriptors the mask tile shape drops the trailing (chunk)
    // dimension of the target shape.
    const int64_t chunkSize = descTy.getChunkSizeAsInt();
    const bool isChunked = chunkSize > 1;
    SmallVector<int64_t> maskTileShape(*tileShape);
    if (isChunked)
      maskTileShape.pop_back();

    SmallVector<Type> maskTileTypes = getUnrolledTypes(maskTy, maskTileShape);
    SmallVector<Value> packedMasks =
        pack(op.getMask(), maskTileTypes, maskTileShape, loc, rewriter);

    SmallVector<Value> tileMasks;
    if (isChunked) {
      // the mask is reused across the chunk_size dimension
      const int64_t numNewChunks = chunkSize / tileShape->back();
      for (Value maskTile : packedMasks)
        tileMasks.append(numNewChunks, maskTile);
    } else {
      tileMasks = packedMasks;
    }

    SmallVector<Value> tileLoads;
    for (auto [descTile, maskTile] : llvm::zip(tileDescs, tileMasks)) {
      Value loaded = xegpu::LoadGatherOp::create(
          rewriter, loc, tileResultTy, descTile, maskTile, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      tileLoads.push_back(loaded);
    }

    Value merged = unpack(tileLoads, op.getType(), *tileShape, loc, rewriter);
    rewriter.replaceOp(op, merged);
    return success();
  }
};
|
|
|
|
/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
|
|
/// load).
|
|
/// It unrolls the offsets and mask operands accordingly, and creates multiple
|
|
/// LoadGatherOp with the unrolled operands.
|
|
struct UnrollLoadGatherOpWithOffset
|
|
: public UnrollPattern<xegpu::LoadGatherOp> {
|
|
using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
|
|
LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
|
|
Value offsets = op.getOffsets();
|
|
Value mask = op.getMask();
|
|
|
|
// Only handle the case where offsets are present (scattered load)
|
|
if (!offsets)
|
|
return failure();
|
|
|
|
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
|
|
if (!targetShape)
|
|
return failure();
|
|
|
|
SmallVector<int64_t> targetMaskShape(*targetShape);
|
|
int64_t chunkSize = 1;
|
|
if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
|
|
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
|
|
chunkSize = intAttr.getInt();
|
|
}
|
|
|
|
// Unroll mask and offsets with correct shape
|
|
VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
|
|
VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
|
|
Type elemTy = valueTy.getElementType();
|
|
VectorType newValueTy = VectorType::get(*targetShape, elemTy);
|
|
|
|
SmallVector<Type> convertedMaskTypes;
|
|
SmallVector<Value> convertedMasks;
|
|
SmallVector<Type> convertedOffsetTypes;
|
|
SmallVector<Value> convertedOffsets;
|
|
|
|
if (chunkSize > 1) {
|
|
// For chunked loads, mask and offsets have one less dimension
|
|
targetMaskShape.pop_back();
|
|
int64_t blockedChunkSize = targetShape->back();
|
|
int64_t numNewChunks = chunkSize / blockedChunkSize;
|
|
chunkSize = blockedChunkSize;
|
|
|
|
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
|
|
convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
|
|
|
|
SmallVector<Value> convertedMasksBase =
|
|
pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
|
|
SmallVector<Value> convertedOffsetsBase =
|
|
pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
|
|
|
|
for (auto maskVal : convertedMasksBase)
|
|
convertedMasks.append(numNewChunks, maskVal);
|
|
|
|
for (auto [baseOffset, offsetType] :
|
|
llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
|
|
for (int64_t i = 0; i < numNewChunks; ++i) {
|
|
Value inc = arith::ConstantIndexOp::create(rewriter, loc,
|
|
i * blockedChunkSize);
|
|
Value incVec =
|
|
vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
|
|
Value offsetVal =
|
|
arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
|
|
convertedOffsets.push_back(offsetVal);
|
|
}
|
|
}
|
|
} else {
|
|
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
|
|
convertedMasks =
|
|
pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
|
|
|
|
convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
|
|
convertedOffsets =
|
|
pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
|
|
}
|
|
|
|
auto layout = op.getLayoutAttr();
|
|
if (layout)
|
|
layout = layout.dropInstData();
|
|
|
|
SmallVector<Value> newOps;
|
|
for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
|
|
auto newOp = xegpu::LoadGatherOp::create(
|
|
rewriter, loc, newValueTy, op.getSource(), o, m,
|
|
rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
|
|
op.getL2HintAttr(), op.getL3HintAttr(), layout);
|
|
newOps.push_back(newOp);
|
|
}
|
|
|
|
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
|
|
rewriter.replaceOp(op, castOp);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// This pattern handles the unrolling of StoreScatterOp with offsets (scattered
|
|
/// store).
|
|
/// It unrolls the offsets and mask operands accordingly, and creates multiple
|
|
/// StoreScatterOp with the unrolled operands.
|
|
struct UnrollStoreScatterOpWithOffsets
|
|
: public UnrollPattern<xegpu::StoreScatterOp> {
|
|
using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
|
|
LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
|
|
Value offsets = op.getOffsets();
|
|
Value mask = op.getMask();
|
|
|
|
// Only handle the case where offsets are present (scattered store)
|
|
if (!offsets)
|
|
return failure();
|
|
|
|
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
|
|
if (!targetShape)
|
|
return failure();
|
|
|
|
int64_t chunkSize = 1;
|
|
if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
|
|
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
|
|
chunkSize = intAttr.getInt();
|
|
}
|
|
|
|
SmallVector<int64_t> targetMaskShape(*targetShape);
|
|
VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
|
|
VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
|
|
|
|
SmallVector<Type> convertedMaskTypes;
|
|
SmallVector<Value> convertedMasks;
|
|
SmallVector<Type> convertedOffsetTypes;
|
|
SmallVector<Value> convertedOffsets;
|
|
|
|
if (chunkSize > 1) {
|
|
targetMaskShape.pop_back();
|
|
int64_t blockedChunkSize = targetShape->back();
|
|
int64_t numNewChunks = chunkSize / blockedChunkSize;
|
|
chunkSize = blockedChunkSize;
|
|
|
|
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
|
|
convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
|
|
|
|
SmallVector<Value> convertedMasksBase =
|
|
pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
|
|
SmallVector<Value> convertedOffsetsBase =
|
|
pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
|
|
|
|
for (auto maskVal : convertedMasksBase)
|
|
convertedMasks.append(numNewChunks, maskVal);
|
|
|
|
for (auto [baseOffset, offsetType] :
|
|
llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
|
|
for (int64_t i = 0; i < numNewChunks; ++i) {
|
|
Value inc = arith::ConstantIndexOp::create(rewriter, loc,
|
|
i * blockedChunkSize);
|
|
Value incVec =
|
|
vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
|
|
Value offsetVal =
|
|
arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
|
|
convertedOffsets.push_back(offsetVal);
|
|
}
|
|
}
|
|
} else {
|
|
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
|
|
convertedMasks =
|
|
pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
|
|
|
|
convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
|
|
convertedOffsets =
|
|
pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
|
|
}
|
|
|
|
SmallVector<Type> convertedValTypes =
|
|
getUnrolledTypes(valueTy, *targetShape);
|
|
SmallVector<Value> convertedValues =
|
|
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
|
|
|
|
auto layout = op.getLayoutAttr();
|
|
if (layout)
|
|
layout = layout.dropInstData();
|
|
|
|
for (auto [v, o, m] :
|
|
llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
|
|
xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
|
|
rewriter.getI64IntegerAttr(chunkSize),
|
|
op.getL1HintAttr(), op.getL2HintAttr(),
|
|
op.getL3HintAttr(), layout);
|
|
}
|
|
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Unrolls a PrefetchOp that uses a tensor descriptor (no explicit offsets
/// operand): the descriptor is packed into tiles and one prefetch is emitted
/// per tile. The original op is erased afterwards.
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::TensorDescType descTy = op.getTensorDescType();

    // TODO: handle the unstructure source case (!tdesTy)
    if (!descTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> tileShape = getTargetShape(op);
    if (!tileShape)
      return failure();

    Location loc = op.getLoc();
    SmallVector<Type> tileDescTypes = getUnrolledTypes(descTy, *tileShape);
    SmallVector<Value> tileDescs =
        pack(op.getTensorDesc(), tileDescTypes, *tileShape, loc, rewriter);

    // Each tile prefetch reuses the original attributes after
    // `dropInstDataOnAttrs`.
    for (Value tileDesc : tileDescs)
      xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), tileDesc,
                                xegpu::dropInstDataOnAttrs(op->getAttrs()));

    rewriter.eraseOp(op);
    return success();
  }
};
|
|
|
|
/// Unrolls a tensor-descriptor based xegpu::StoreScatterOp: the value, the
/// descriptor and the mask are split into tiles and one store is emitted per
/// tile.
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;

  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType descTy = op.getTensorDescType();

    // TODO: handle the unstructured source case (no tensor descriptor type).
    if (!descTy)
      return failure();
    // The form carrying explicit offsets is not handled by this pattern.
    if (op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t chunkSize = descTy.getChunkSizeAsInt();
    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    // Tile the tensor descriptor.
    SmallVector<Type> descTileTypes = getUnrolledTypes(descTy, *targetShape);
    SmallVector<Value> descTiles = pack(op.getTensorDesc(), descTileTypes,
                                        *targetShape, loc, rewriter);

    // Tile the mask. With a chunk size greater than one, the mask does not
    // carry the trailing chunk dimension, so that dimension is dropped from
    // the tile shape and every mask tile is reused across the chunk tiles.
    SmallVector<int64_t> maskTileShape(*targetShape);
    int64_t maskRepeat = 1;
    if (chunkSize > 1) {
      maskTileShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      maskRepeat = chunkSize / blockedChunkSize;
    }
    SmallVector<Type> maskTileTypes = getUnrolledTypes(maskTy, maskTileShape);
    SmallVector<Value> maskTiles;
    for (Value maskTile :
         pack(op.getMask(), maskTileTypes, maskTileShape, loc, rewriter))
      maskTiles.append(maskRepeat, maskTile);

    // Tile the value to be stored.
    SmallVector<Type> valueTileTypes = getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> valueTiles =
        pack(op.getValue(), valueTileTypes, *targetShape, loc, rewriter);

    // Emit one store per value tile, keeping the original cache hints.
    for (size_t idx = 0; idx < valueTiles.size(); ++idx) {
      Value valueTile = valueTiles[idx];
      Value descTile = descTiles[idx];
      Value maskTile = op.getMask() ? maskTiles[idx] : nullptr;
      xegpu::StoreScatterOp::create(rewriter, loc, valueTile, descTile,
                                    maskTile, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};
|
|
|
|
/// Unrolls xegpu::UpdateOffsetOp on a scattered tensor descriptor: the
/// descriptor and the offset vector are split into tiles, one update is
/// emitted per tile, and the results are recombined into the original type.
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;

  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType descTy = op.getTensorDescType();

    // Only scattered tensor descriptors are handled here.
    if (!descTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    // Tile the tensor descriptor.
    SmallVector<Type> descTileTypes = getUnrolledTypes(descTy, *targetShape);
    SmallVector<Value> descTiles = pack(op.getTensorDesc(), descTileTypes,
                                        *targetShape, loc, rewriter);

    TypedValue<::mlir::VectorType> offsets = op.getOffsets();
    VectorType offsetsTy = offsets.getType();

    // Tile the offsets. With a chunk size greater than one, the offsets do
    // not carry the trailing chunk dimension, so that dimension is dropped
    // from the tile shape and every offset tile is reused across the chunk
    // tiles.
    ArrayRef<int64_t> offsetTileShape(*targetShape);
    int64_t offsetRepeat = 1;
    int64_t chunkSize = descTy.getChunkSizeAsInt();
    if (chunkSize > 1) {
      offsetTileShape = offsetTileShape.drop_back();
      int64_t blockedChunkSize = targetShape->back();
      offsetRepeat = chunkSize / blockedChunkSize;
    }
    SmallVector<Type> offsetTileTypes =
        getUnrolledTypes(offsetsTy, offsetTileShape);
    SmallVector<Value> offsetTiles;
    for (Value offsetTile :
         pack(offsets, offsetTileTypes, offsetTileShape, loc, rewriter))
      offsetTiles.append(offsetRepeat, offsetTile);

    // Emit one update per (descriptor tile, offset tile) pair.
    SmallVector<Value> updatedTiles;
    for (auto [descTile, offsetTile] : llvm::zip(descTiles, offsetTiles)) {
      auto updated = xegpu::UpdateOffsetOp::create(
          rewriter, loc, descTile.getType(), descTile, offsetTile);
      updatedTiles.push_back(updated);
    }

    Value combined =
        unpack(updatedTiles, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, combined);
    return success();
  }
};
|
|
|
|
/// Unrolls xegpu::LoadMatrixOp: the result vector is split into tiles of the
/// target shape, one load is emitted per tile at the original offsets plus
/// the tile offsets, and the loaded tiles are recombined into the original
/// result type.
struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;

  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    assert(valueTy && "the value type must be vector type!");

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape ||
        targetShape->size() != static_cast<size_t>(valueTy.getRank()))
      return failure();

    // The layout may be absent, so it must not be cast or dereferenced
    // unconditionally. A layout that is present but is not a LayoutAttr is
    // rejected here, before any IR has been created, instead of crashing in
    // dropInstData() on a null attribute.
    xegpu::LayoutAttr layout;
    if (auto layoutAttr = op.getLayoutAttr()) {
      layout = llvm::dyn_cast<xegpu::LayoutAttr>(layoutAttr);
      if (!layout)
        return rewriter.notifyMatchFailure(
            op, "expected the layout to be an xegpu::LayoutAttr");
      layout = layout.dropInstData();
    }

    Type elemTy = valueTy.getElementType();
    ArrayRef<int64_t> shape = valueTy.getShape();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    // Compute the offsets of every tile: original offsets + tile offsets.
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          getAsIndexOpFoldResult(op.getContext(), offsets));
      offsetsList.push_back(adds);
    }

    // Emit one load per tile, using the layout without inst_data.
    SmallVector<Value> newOps;
    for (const SmallVector<OpFoldResult> &offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(
          rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
|
|
|
|
/// Unrolls xegpu::StoreMatrixOp: the data vector is split into tiles of the
/// target shape and one store is emitted per tile at the original offsets
/// plus the tile offsets.
struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;

  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
    assert(valueTy && "the value type must be vector type!");
    ArrayRef<int64_t> shape = valueTy.getShape();

    // The layout may be absent, so it must not be cast or dereferenced
    // unconditionally. A layout that is present but is not a LayoutAttr is
    // rejected here, before any IR has been created, instead of crashing in
    // dropInstData() on a null attribute. The inst_data-free layout is the
    // same for every tile, so it is computed once.
    xegpu::LayoutAttr layout;
    if (auto layoutAttr = op.getLayoutAttr()) {
      layout = llvm::dyn_cast<xegpu::LayoutAttr>(layoutAttr);
      if (!layout)
        return rewriter.notifyMatchFailure(
            op, "expected the layout to be an xegpu::LayoutAttr");
      layout = layout.dropInstData();
    }

    // Tile the data to be stored.
    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);

    // Compute the offsets of every tile: original offsets + tile offsets.
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          getAsIndexOpFoldResult(op.getContext(), offsets));
      offsetsList.push_back(adds);
    }

    // Emit one store per (value tile, tile offsets) pair.
    for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
                                   layout);

    rewriter.eraseOp(op);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers every XeGPU unroll pattern defined in this file into `patterns`,
/// each constructed with the pattern set's context and the given `options`.
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  auto *context = patterns.getContext();
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
               UnrollPrefetchOp, UnrollUpdateOffsetOp, UnrollLoadMatrixOp,
               UnrollStoreMatrixOp, UnrollLoadGatherOpWithOffset,
               UnrollStoreScatterOpWithOffsets>(context, options);
}
|