The LLVM style guide prefers not to use braced initializer lists to call a constructor. It also prefers an equals sign before the opening curly brace when a braced initializer list is used to initialize a variable. See https://llvm.org/docs/CodingStandards.html#do-not-use-braced-initializer-lists-to-call-a-constructor for more details. The style guide does not explain the reasoning well; an Abseil article mentions a few benefits, e.g., avoiding the most vexing parse. See https://abseil.io/tips/88 for more details. Signed-off-by: hanhanW <hanhan0912@gmail.com>
336 lines
12 KiB
C++
336 lines
12 KiB
C++
//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements lowering of vector operations to XeGPU dialect ops.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
#include <algorithm>
|
|
#include <optional>
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
// Return true if value represents a zero constant.
|
|
static bool isZeroConstant(Value val) {
|
|
auto constant = val.getDefiningOp<arith::ConstantOp>();
|
|
if (!constant)
|
|
return false;
|
|
|
|
return TypeSwitch<Attribute, bool>(constant.getValue())
|
|
.Case<FloatAttr>(
|
|
[](auto floatAttr) { return floatAttr.getValue().isZero(); })
|
|
.Case<IntegerAttr>(
|
|
[](auto intAttr) { return intAttr.getValue().isZero(); })
|
|
.Default([](auto) { return false; });
|
|
}
|
|
|
|
/// Common legality check for vector load/store lowerings: only 1-D and 2-D
/// vectors are accepted. On failure, a match-failure diagnostic is attached
/// to `op` through `rewriter`.
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Only the vector is validated here: plain vector store and load ops
  // already guarantee an XeGPU-compatible memref source.
  int64_t rank = vecTy.getRank();
  bool isSupportedRank = rank == 1 || rank == 2;
  if (isSupportedRank)
    return success();
  return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
}
|
|
|
|
/// Checks the conditions shared by the vector transfer lowerings. The
/// transfer must meet all of the following:
///   - it is unmasked;
///   - its source is a memref whose innermost stride is 1;
///   - it moves a 1-D or 2-D vector;
///   - it uses a projected-permutation map with no constant results;
///   - it accesses only the innermost memref dimensions.
/// Out-of-bounds dimensions are accepted only for vectors of rank >= 2. On
/// failure, a match-failure diagnostic is reported through `rewriter`.
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  // A 0-D memref (e.g. the source of a broadcasting transfer) has an empty
  // stride list and no innermost dimension, so reject it before reading
  // `strides.back()`, which would otherwise be undefined behavior.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.empty() ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");

  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    // Constant results were rejected above, so every result is a dimension.
    auto dim = cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
|
|
|
|
/// Creates an `xegpu.create_nd_tdesc` op of type `descType` describing `src`
/// at the element `offsets`.
///
/// If the source memref has a fully static shape and fully static strides,
/// the short-form builder is used and reads both from the memref type.
/// Otherwise the shape and strides are passed explicitly:
///   - static entries go into the constant attributes;
///   - dynamic entries become SSA operands, queried at runtime from the
///     memref via `memref.extract_strided_metadata`.
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  // The short-form builder takes shape and strides from the memref type, so
  // it is only usable when neither contains dynamic entries. A static shape
  // alone is not enough: e.g. a subview result can have dynamic strides.
  if (srcTy.hasStaticShape() && llvm::none_of(strides, ShapedType::isDynamic))
    return rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                  getAsOpFoldResult(offsets));

  // Query the runtime sizes and strides of `src`. The real strides are used
  // instead of being reconstructed as products of the sizes. The
  // reconstruction is only valid for contiguous (identity) layouts and gives
  // wrong results for strided memrefs.
  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);

  // Split the offsets into constant and dynamic parts.
  SmallVector<int64_t> constOffsets;
  SmallVector<Value> dynOffsets;
  for (Value offsetVal : offsets) {
    std::optional<int64_t> staticVal = getConstantIntValue(offsetVal);
    if (!staticVal)
      dynOffsets.push_back(offsetVal);
    constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
  }

  // Only dynamic sizes/strides are passed as operands; static ones are
  // encoded in the constant attributes below.
  ArrayRef<int64_t> shape = srcTy.getShape();
  SmallVector<Value> dynShapes;
  SmallVector<Value> dynStrides;
  for (unsigned i = 0, e = srcTy.getRank(); i < e; ++i) {
    if (ShapedType::isDynamic(shape[i]))
      dynShapes.push_back(metadata.getSizes()[i]);
    if (ShapedType::isDynamic(strides[i]))
      dynStrides.push_back(metadata.getStrides()[i]);
  }

  return rewriter.create<xegpu::CreateNdDescOp>(
      loc, descType, src, dynOffsets, dynShapes, dynStrides,
      DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
      DenseI64ArrayAttr::get(rewriter.getContext(), shape),
      DenseI64ArrayAttr::get(rewriter.getContext(), strides));
}
|
|
|
|
/// Lowers `vector.transfer_read` from a memref to `xegpu.create_nd_tdesc`
/// followed by `xegpu.load_nd`.
///   - Out-of-bounds reads are only accepted when the padding value is a zero
///     constant.
///   - A non-minor-identity permutation map is lowered to a transposed load.
///     This requires an integer or float element type of at least 32 bits.
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    // Transposed loads need integer/float elements of at least 32 bits.
    // isIntOrFloat() is checked first because getIntOrFloatBitWidth()
    // asserts on other element types such as `index`.
    constexpr unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        (!elementType.isIntOrFloat() ||
         elementType.getIntOrFloatBitWidth() < minTransposeBitWidth))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    // transferPreconditions() already verified that the source is a memref.
    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           cast<TypedValue<MemRefType>>(readOp.getSource()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr : rewriter.getDenseI64ArrayAttr({1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
|
|
|
|
struct TransferWriteLowering
|
|
: public OpRewritePattern<vector::TransferWriteOp> {
|
|
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
Location loc = writeOp.getLoc();
|
|
|
|
if (failed(transferPreconditions(rewriter, writeOp)))
|
|
return failure();
|
|
|
|
AffineMap map = writeOp.getPermutationMap();
|
|
if (!map.isMinorIdentity())
|
|
return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
|
|
|
|
VectorType vecTy = writeOp.getVectorType();
|
|
auto descType = xegpu::TensorDescType::get(
|
|
vecTy.getShape(), vecTy.getElementType(),
|
|
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
|
|
xegpu::MemorySpace::Global);
|
|
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
|
|
rewriter, loc, descType,
|
|
dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
|
|
writeOp.getIndices());
|
|
|
|
// By default, no specific caching policy is assigned.
|
|
xegpu::CachePolicyAttr hint = nullptr;
|
|
auto storeOp =
|
|
rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
|
|
/*l1_hint=*/hint,
|
|
/*l2_hint=*/hint, /*l3_hint=*/hint);
|
|
rewriter.replaceOp(writeOp, storeOp);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Rewrites a 1-D or 2-D `vector.load` as `xegpu.create_nd_tdesc` plus
/// `xegpu.load_nd`. Boundary checking is requested only for rank > 1.
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, resultTy)))
      return failure();

    Location loc = loadOp.getLoc();
    // Boundary check is available only for block instructions.
    bool isBlock = resultTy.getRank() > 1;
    auto tdescTy = xegpu::TensorDescType::get(
        resultTy.getShape(), resultTy.getElementType(), /*array_length=*/1,
        /*boundary_check=*/isBlock, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp tdesc = createNdDescriptor(
        rewriter, loc, tdescTy, loadOp.getBase(), loadOp.getIndices());

    // No particular cache policy is requested at any cache level.
    xegpu::CachePolicyAttr noHint = nullptr;
    rewriter.replaceOpWithNewOp<xegpu::LoadNdOp>(
        loadOp, resultTy, tdesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/noHint, /*l2_hint=*/noHint, /*l3_hint=*/noHint);
    return success();
  }
};
|
|
|
|
/// Rewrites a 1-D or 2-D `vector.store` as `xegpu.create_nd_tdesc` plus
/// `xegpu.store_nd`. Boundary checking is requested only for rank > 1.
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    TypedValue<VectorType> valueToStore = storeOp.getValueToStore();
    VectorType valueTy = valueToStore.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, valueTy)))
      return failure();

    Location loc = storeOp.getLoc();
    // Boundary check is available only for block instructions.
    bool isBlock = valueTy.getRank() > 1;
    auto tdescTy = xegpu::TensorDescType::get(
        valueTy.getShape(), valueTy.getElementType(), /*array_length=*/1,
        /*boundary_check=*/isBlock, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp tdesc = createNdDescriptor(
        rewriter, loc, tdescTy, storeOp.getBase(), storeOp.getIndices());

    // No particular cache policy is requested at any cache level.
    xegpu::CachePolicyAttr noHint = nullptr;
    rewriter.replaceOpWithNewOp<xegpu::StoreNdOp>(
        storeOp, valueToStore, tdesc, /*l1_hint=*/noHint, /*l2_hint=*/noHint,
        /*l3_hint=*/noHint);
    return success();
  }
};
|
|
|
|
struct ConvertVectorToXeGPUPass
|
|
: public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorToXeGPUConversionPatterns(patterns);
|
|
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds the patterns lowering vector transfer/load/store ops to XeGPU to
/// `patterns`.
void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  // vector.transfer_read / vector.transfer_write lowerings.
  patterns.add<TransferReadLowering, TransferWriteLowering>(ctx);
  // vector.load / vector.store lowerings.
  patterns.add<LoadLowering, StoreLowering>(ctx);
}
|
|
|
|
std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
|
|
return std::make_unique<ConvertVectorToXeGPUPass>();
|
|
}
|