### whats the problem mlir-opt could crash while verifying amdgpu.dpp when its operands had vector types, such as ARM SME tile vectors produced by arm_sme.get_tile. The crash occurred during IR verification, before any lowering or passes ran. ### why it happens DPPOp::verify() called Type::getIntOrFloatBitWidth() on the operand type. When the operand was a VectorType, this hit an assertion because only scalar integer and float types have a bitwidth. ### whats the fix Query the bitwidth on the element type using getElementTypeOrSelf() instead of the container type. Add a regression test to ensure amdgpu.dpp verification no longer asserts on vector operand types. Fixes #178128
1236 lines
46 KiB
C++
1236 lines
46 KiB
C++
//===- AMDGPUOps.cpp - MLIR AMDGPU dialect operations ---------------------===//
|
||
//
|
||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
// See https://llvm.org/LICENSE.txt for license information.
|
||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
//
|
||
//===----------------------------------------------------------------------===//
|
||
//
|
||
// This file implements the AMDGPU dialect operations, their verifiers, and
|
||
// their canonicalizations.
|
||
//
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
|
||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
|
||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||
#include "mlir/IR/Builders.h"
|
||
#include "mlir/IR/BuiltinTypes.h"
|
||
#include "mlir/IR/Diagnostics.h"
|
||
#include "mlir/IR/Matchers.h"
|
||
#include "mlir/IR/OpImplementation.h"
|
||
#include "mlir/IR/PatternMatch.h"
|
||
#include "mlir/IR/TypeUtilities.h"
|
||
#include "llvm/ADT/DenseMap.h"
|
||
#include "llvm/ADT/SmallVector.h"
|
||
|
||
#include <algorithm>
|
||
#include <cstdint>
|
||
#include <limits>
|
||
#include <optional>
|
||
|
||
using namespace mlir;
|
||
using namespace mlir::amdgpu;
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// 8-bit float ops
|
||
//===----------------------------------------------------------------------===//
|
||
LogicalResult PackedTrunc2xFp8Op::verify() {
|
||
if (getExisting() && getExisting().getType() != getResult().getType())
|
||
return emitOpError("existing values must have same type as result");
|
||
return success();
|
||
}
|
||
|
||
LogicalResult PackedStochRoundFp8Op::verify() {
|
||
if (getExisting() && getExisting().getType() != getResult().getType())
|
||
return emitOpError("existing values must have same type as result");
|
||
return success();
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// mxfp float ops
|
||
//===----------------------------------------------------------------------===//
|
||
LogicalResult PackedScaledTruncOp::verify() {
|
||
if (getExisting() && getExisting().getType() != getResult().getType())
|
||
return emitOpError("existing values must have same type as result");
|
||
return success();
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// FatRawBufferCastOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
/// Convert the type `source` to one with the same sizes and strides - and
|
||
/// offset, unless `stripOffset` is true, in which case the offset is reset to
|
||
/// 0, if the offset should be reset but the layout of `source` isn't either the
|
||
/// identity layout or a strided layout, this function fails.
|
||
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
|
||
bool resetOffset) {
|
||
MLIRContext *ctx = source.getContext();
|
||
MemRefType::Builder mb(source);
|
||
mb.setMemorySpace(
|
||
amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
|
||
MemRefLayoutAttrInterface layout = source.getLayout();
|
||
if (resetOffset && !layout.isIdentity()) {
|
||
auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
|
||
if (!stridedLayout)
|
||
return failure();
|
||
MemRefLayoutAttrInterface newLayout =
|
||
StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
|
||
// Special case: if resetting the offset causes the strided layout to become
|
||
// the identity layout, then reset to the identity layout.
|
||
// TODO: this'll get a lot simpler when we have the contiguous layout.
|
||
SmallVector<int64_t> stridesIfIdentity;
|
||
if (source.hasStaticShape()) {
|
||
stridesIfIdentity = computeSuffixProduct(source.getShape());
|
||
} else if (source.getRank() <= 1) {
|
||
stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
|
||
}
|
||
if (stridesIfIdentity == stridedLayout.getStrides()) {
|
||
newLayout = AffineMapAttr::get(
|
||
AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
|
||
}
|
||
mb.setLayout(newLayout);
|
||
}
|
||
return (MemRefType)(mb);
|
||
}
|
||
|
||
LogicalResult FatRawBufferCastOp::inferReturnTypes(
|
||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||
Adaptor adaptor(operands, attributes, properties, regions);
|
||
auto sourceType =
|
||
dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
|
||
if (!sourceType)
|
||
return failure();
|
||
FailureOr<MemRefType> resultType =
|
||
getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
|
||
if (failed(resultType))
|
||
return failure();
|
||
inferredReturnTypes = SmallVector<Type>{*resultType};
|
||
return success();
|
||
}
|
||
|
||
FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(OpBuilder &builder,
|
||
int resultIndex,
|
||
int dim) {
|
||
assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
|
||
return memref::getMixedSize(builder, getLoc(), getSource(), dim);
|
||
}
|
||
|
||
LogicalResult FatRawBufferCastOp::verify() {
|
||
FailureOr<MemRefType> expectedResultType =
|
||
getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
|
||
if (failed(expectedResultType))
|
||
return emitOpError("source type ")
|
||
<< getSource().getType() << " can't have its offset reset";
|
||
if (getResult().getType() != *expectedResultType)
|
||
return emitOpError("expected result type to be ")
|
||
<< *expectedResultType << " but got " << getResult().getType();
|
||
return success();
|
||
}
|
||
|
||
static bool hasGlobalMemorySpace(Attribute memorySpace) {
|
||
if (!memorySpace)
|
||
return true;
|
||
if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
|
||
return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
|
||
if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
|
||
return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
|
||
return false;
|
||
}
|
||
|
||
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
|
||
if (!memorySpace)
|
||
return false;
|
||
if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
|
||
return intMemorySpace.getInt() == 3;
|
||
if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
|
||
return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
|
||
return false;
|
||
}
|
||
|
||
static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
|
||
if (!memorySpace)
|
||
return false;
|
||
if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
|
||
return intMemorySpace.getInt() == 7;
|
||
if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
|
||
return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
|
||
return false;
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// RawBuffer*Op
|
||
//===----------------------------------------------------------------------===//
|
||
template <typename T>
|
||
static LogicalResult verifyRawBufferOp(T &op) {
|
||
MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
|
||
bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
|
||
|
||
if (!isGlobal)
|
||
return op.emitOpError(
|
||
"Buffer ops must operate on a memref in global memory");
|
||
if (!bufferType.hasRank())
|
||
return op.emitOpError(
|
||
"Cannot meaningfully buffer_store to an unranked memref");
|
||
if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
|
||
return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
|
||
" indices to memref");
|
||
return success();
|
||
}
|
||
|
||
LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
|
||
|
||
LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
|
||
|
||
LogicalResult RawBufferAtomicFaddOp::verify() {
|
||
return verifyRawBufferOp(*this);
|
||
}
|
||
|
||
LogicalResult RawBufferAtomicFmaxOp::verify() {
|
||
return verifyRawBufferOp(*this);
|
||
}
|
||
|
||
LogicalResult RawBufferAtomicSmaxOp::verify() {
|
||
return verifyRawBufferOp(*this);
|
||
}
|
||
|
||
LogicalResult RawBufferAtomicUminOp::verify() {
|
||
return verifyRawBufferOp(*this);
|
||
}
|
||
|
||
LogicalResult RawBufferAtomicCmpswapOp::verify() {
|
||
return verifyRawBufferOp(*this);
|
||
}
|
||
|
||
static std::optional<uint32_t> getConstantUint32(Value v) {
|
||
APInt cst;
|
||
if (!v.getType().isInteger(32))
|
||
return std::nullopt;
|
||
if (matchPattern(v, m_ConstantInt(&cst)))
|
||
return cst.getZExtValue();
|
||
return std::nullopt;
|
||
}
|
||
|
||
template <typename OpType>
|
||
static bool staticallyOutOfBounds(OpType op) {
|
||
if (!op.getBoundsCheck())
|
||
return false;
|
||
MemRefType bufferType = op.getMemref().getType();
|
||
if (!bufferType.hasStaticShape())
|
||
return false;
|
||
int64_t offset;
|
||
SmallVector<int64_t> strides;
|
||
if (failed(bufferType.getStridesAndOffset(strides, offset)))
|
||
return false;
|
||
int64_t result = offset + op.getIndexOffset().value_or(0);
|
||
if (op.getSgprOffset()) {
|
||
std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
|
||
if (!sgprOffset)
|
||
return false;
|
||
result += *sgprOffset;
|
||
}
|
||
if (strides.size() != op.getIndices().size())
|
||
return false;
|
||
int64_t indexVal = 0;
|
||
for (auto pair : llvm::zip(strides, op.getIndices())) {
|
||
int64_t stride = std::get<0>(pair);
|
||
Value idx = std::get<1>(pair);
|
||
std::optional<uint32_t> idxVal = getConstantUint32(idx);
|
||
if (!idxVal)
|
||
return false;
|
||
indexVal += stride * *idxVal;
|
||
}
|
||
result += indexVal;
|
||
if (result > std::numeric_limits<uint32_t>::max())
|
||
// Overflow means don't drop
|
||
return false;
|
||
return result >= bufferType.getNumElements();
|
||
}
|
||
|
||
namespace {
|
||
template <typename OpType>
|
||
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
|
||
using OpRewritePattern<OpType>::OpRewritePattern;
|
||
|
||
LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
|
||
if (!staticallyOutOfBounds(op))
|
||
return failure();
|
||
Type loadType = op.getResult().getType();
|
||
rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
|
||
rw.getZeroAttr(loadType));
|
||
return success();
|
||
}
|
||
};
|
||
|
||
template <typename OpType>
|
||
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
|
||
using OpRewritePattern<OpType>::OpRewritePattern;
|
||
|
||
LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
|
||
if (!staticallyOutOfBounds(op))
|
||
return failure();
|
||
|
||
rw.eraseOp(op);
|
||
return success();
|
||
}
|
||
};
|
||
} // end namespace
|
||
|
||
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||
MLIRContext *context) {
|
||
results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
|
||
}
|
||
|
||
void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||
MLIRContext *context) {
|
||
results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
|
||
}
|
||
|
||
void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
|
||
RewritePatternSet &results, MLIRContext *context) {
|
||
results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
|
||
}
|
||
|
||
void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
|
||
RewritePatternSet &results, MLIRContext *context) {
|
||
results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
|
||
}
|
||
|
||
void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
|
||
RewritePatternSet &results, MLIRContext *context) {
|
||
results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
|
||
}
|
||
|
||
void RawBufferAtomicUminOp::getCanonicalizationPatterns(
|
||
RewritePatternSet &results, MLIRContext *context) {
|
||
results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
|
||
}
|
||
|
||
void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
|
||
RewritePatternSet &results, MLIRContext *context) {
|
||
results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
|
||
context);
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// ScaledExtPackedMatrixOp
|
||
//===----------------------------------------------------------------------===//
|
||
LogicalResult ScaledExtPackedMatrixOp::verify() {
|
||
int blockSize = getBlockSize();
|
||
assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
|
||
|
||
int firstScaleByte = getFirstScaleByte();
|
||
int firstScaleLane = getFirstScaleLane();
|
||
auto sourceType = cast<VectorType>(getSource().getType());
|
||
Type elementType = sourceType.getElementType();
|
||
auto floatType = cast<FloatType>(elementType);
|
||
unsigned bitWidth = floatType.getWidth();
|
||
|
||
assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
|
||
|
||
const bool is_fp8 = bitWidth == 8;
|
||
const bool is_block_16 = blockSize == 16;
|
||
|
||
if (!is_fp8) {
|
||
if (is_block_16) {
|
||
if (!llvm::is_contained({0, 1}, firstScaleByte)) {
|
||
return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
|
||
"or 1 for f4 and f6.");
|
||
}
|
||
} else {
|
||
if (!llvm::is_contained({0, 2}, firstScaleByte)) {
|
||
return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
|
||
"or 2 for f4 and f6.");
|
||
}
|
||
}
|
||
} else {
|
||
if (is_block_16) {
|
||
bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
|
||
((firstScaleLane == 16) && (firstScaleByte == 2));
|
||
if (!is_valid) {
|
||
return emitOpError("blockSize of 16 can only have (firstScaleLane, "
|
||
"firstScaleByte) be (0, 0) or (16, 2) for f8.");
|
||
}
|
||
}
|
||
}
|
||
|
||
return success();
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// WMMAOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
|
||
IntegerAttr &m, IntegerAttr &n,
|
||
IntegerAttr &k) {
|
||
SmallVector<int64_t, 3> dimensions;
|
||
if (parser.parseDimensionList(dimensions, false, false))
|
||
return failure();
|
||
if (dimensions.size() != 3)
|
||
return parser.emitError(parser.getCurrentLocation())
|
||
<< "expected 3 dimensions in MNK dimension list";
|
||
|
||
m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
|
||
n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
|
||
k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
|
||
return success();
|
||
}
|
||
|
||
LogicalResult WMMAOp::verify() {
|
||
auto sourceAType = cast<VectorType>(getSourceA().getType());
|
||
auto sourceBType = cast<VectorType>(getSourceB().getType());
|
||
auto destType = cast<VectorType>(getDestC().getType());
|
||
|
||
Type sourceAElemType = sourceAType.getElementType();
|
||
Type sourceBElemType = sourceBType.getElementType();
|
||
if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
|
||
return emitOpError("source vectors have different lengths: ")
|
||
<< sourceAType << " vs. " << sourceBType;
|
||
}
|
||
|
||
bool isDestFloat = destType.getElementType().isFloat();
|
||
bool isSrcFloat = sourceAElemType.isFloat();
|
||
|
||
if (isDestFloat && !isSrcFloat)
|
||
return emitOpError("expected float sources with float destination");
|
||
if (!isDestFloat && isSrcFloat)
|
||
return emitOpError("expected int sources with int destination");
|
||
|
||
if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
|
||
return emitOpError(
|
||
"source element types must match (except for fp8/bf8) but have ")
|
||
<< sourceAType << " and " << sourceBType;
|
||
}
|
||
|
||
if (isSrcFloat) {
|
||
if (getClamp())
|
||
return emitOpError("clamp flag is not supported for float types");
|
||
if (getUnsignedA() || getUnsignedB())
|
||
return emitOpError("unsigned flags are not supported for float types");
|
||
}
|
||
return success();
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// ScaledWMMAOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
LogicalResult ScaledWMMAOp::verify() {
|
||
// Helper functions for type classification.
|
||
auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
|
||
auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
|
||
auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
|
||
auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
|
||
auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
|
||
auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;
|
||
|
||
auto sourceAType = cast<VectorType>(getSourceA().getType());
|
||
auto sourceBType = cast<VectorType>(getSourceB().getType());
|
||
auto destType = cast<VectorType>(getDestC().getType());
|
||
|
||
// Validate source element types are small floats (fp4/fp6/fp8).
|
||
Type aElemType = sourceAType.getElementType();
|
||
Type bElemType = sourceBType.getElementType();
|
||
|
||
// Validate vector lengths based on dimensions.
|
||
int64_t m = getM();
|
||
int64_t aLen = sourceAType.getNumElements();
|
||
int64_t bLen = sourceBType.getNumElements();
|
||
int64_t expectedOutLen = (m == 16) ? 8 : 16;
|
||
|
||
if (destType.getNumElements() != expectedOutLen)
|
||
return emitOpError("expected output vector of length ")
|
||
<< expectedOutLen << " but got " << destType.getNumElements();
|
||
|
||
if (m == 16) {
|
||
// For 16×16×128: both A and B must be 64 elements.
|
||
if (aLen != 64)
|
||
return emitOpError(
|
||
"for 16x16x128, sourceA must have 64 elements but got ")
|
||
<< aLen;
|
||
if (bLen != 64)
|
||
return emitOpError(
|
||
"for 16x16x128, sourceB must have 64 elements but got ")
|
||
<< bLen;
|
||
} else { // m == 32
|
||
// For 32×16×128: only fp4 is supported, A is 128, B is 64.
|
||
if (!isF4(aElemType) && !isF4(bElemType))
|
||
return emitOpError("32x16x128 only supports fp4 element types");
|
||
|
||
if (aLen != 128)
|
||
return emitOpError(
|
||
"for 32x16x128, sourceA must have 128 elements but got ")
|
||
<< aLen;
|
||
if (bLen != 64)
|
||
return emitOpError(
|
||
"for 32x16x128, sourceB must have 64 elements but got ")
|
||
<< bLen;
|
||
|
||
// For 32x16x128, matrix A uses all 32 lanes so a_first_scale_lane must be
|
||
// 0.
|
||
if (getAFirstScaleLane() != 0)
|
||
return emitOpError("for 32x16x128, a_first_scale_lane must be 0");
|
||
}
|
||
|
||
// Validate scale types and their compatibility with matrix element types.
|
||
auto scaleAType = cast<VectorType>(getScaleA().getType());
|
||
auto scaleBType = cast<VectorType>(getScaleB().getType());
|
||
Type scaleAElemType = scaleAType.getElementType();
|
||
Type scaleBElemType = scaleBType.getElementType();
|
||
|
||
// Validate scale element types are valid scale f8 types (E8M0FNU or E4M3FN).
|
||
if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
|
||
return emitOpError(
|
||
"scale operands must have f8 element types (E8M0FNU or E4M3FN)");
|
||
|
||
// Any matrices A/B (fp8|fp6|fp4) with E8M0 scales for matrix A/B are valid.
|
||
if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
|
||
return success();
|
||
|
||
// Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M3|E4M3).
|
||
if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
|
||
isF4(bElemType) && isE4M3(scaleBElemType))
|
||
return success();
|
||
|
||
// Matrix A (F4) x Matrix B (F8|F6) with Scale A (E5M3|E4M3), Scale B (E8M0).
|
||
if (isF4(aElemType) && isE4M3(scaleAElemType) &&
|
||
(isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
|
||
return success();
|
||
|
||
// Matrix A (F4) x Matrix B (F4) with Scale A (E4M3), Scale B (E4M3).
|
||
if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
|
||
isE4M3(scaleBElemType))
|
||
return success();
|
||
|
||
// No valid combination matched.
|
||
return emitOpError("invalid combination of matrix and scale types: ")
|
||
<< "sourceA=" << aElemType << ", scaleA=" << scaleAElemType
|
||
<< ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType;
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// MFMAOp
|
||
//===----------------------------------------------------------------------===//
|
||
LogicalResult MFMAOp::verify() {
|
||
constexpr uint32_t waveSize = 64;
|
||
Builder b(getContext());
|
||
|
||
Type sourceType = getSourceA().getType();
|
||
Type destType = getDestC().getType();
|
||
|
||
Type sourceElem = sourceType, destElem = destType;
|
||
uint32_t sourceLen = 1, destLen = 1;
|
||
if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
|
||
sourceLen = sourceVector.getNumElements();
|
||
sourceElem = sourceVector.getElementType();
|
||
}
|
||
if (auto destVector = dyn_cast<VectorType>(destType)) {
|
||
destLen = destVector.getNumElements();
|
||
destElem = destVector.getElementType();
|
||
}
|
||
|
||
Type sourceBType = getSourceB().getType();
|
||
if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
|
||
int64_t sourceBLen = 1;
|
||
Type sourceBElem = sourceBType;
|
||
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
|
||
sourceBLen = sourceBVector.getNumElements();
|
||
sourceBElem = sourceBVector.getElementType();
|
||
}
|
||
if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
|
||
!sourceBElem.isFloat(4))
|
||
return emitOpError("expected both source operands to have small-float "
|
||
"elements if one does");
|
||
if (sourceLen != sourceBLen)
|
||
return emitOpError(
|
||
"expected both small-float source vectors to have the same length");
|
||
} else {
|
||
if (sourceType != sourceBType)
|
||
return emitOpError("expected both non-small-float source operand types "
|
||
"to match exactly");
|
||
}
|
||
// Normalize the wider integer types the compiler expects to i8.
|
||
if (sourceElem.isInteger(32)) {
|
||
sourceLen *= 4;
|
||
sourceElem = b.getI8Type();
|
||
}
|
||
if (sourceElem.isInteger(64)) {
|
||
sourceLen *= 8;
|
||
sourceElem = b.getI8Type();
|
||
}
|
||
|
||
int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
|
||
if (sourceLen != numSourceElems)
|
||
return emitOpError("expected " + Twine(numSourceElems) +
|
||
" source values for this operation but got " +
|
||
Twine(sourceLen));
|
||
|
||
int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
|
||
if (destLen != numDestElems)
|
||
return emitOpError("expected " + Twine(numDestElems) +
|
||
" result values for this operation but got " +
|
||
Twine(destLen));
|
||
|
||
if (destElem.isF64() && getBlgp() != MFMAPermB::none)
|
||
return emitOpError(
|
||
"double-precision ops do not support permuting lanes of B");
|
||
if (destElem.isF64() && getCbsz() != 0)
|
||
return emitOpError(
|
||
"double-precision ops do not support permuting lanes of A");
|
||
if (getAbid() >= (1u << getCbsz()))
|
||
return emitOpError(
|
||
"block ID for permuting A (abid) must be below 2 ** cbsz");
|
||
|
||
if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
|
||
return emitOpError(
|
||
"negation flags only available for double-precision operations");
|
||
|
||
return success();
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// SparseMFMAOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
LogicalResult SparseMFMAOp::verify() {
|
||
constexpr uint32_t waveSize = 64;
|
||
|
||
auto sparseType = cast<VectorType>(getSourceA().getType());
|
||
auto denseType = cast<VectorType>(getSourceB().getType());
|
||
auto destType = cast<VectorType>(getDestC().getType());
|
||
|
||
Type sparseElem = sparseType.getElementType();
|
||
Type denseElem = denseType.getElementType();
|
||
int64_t sparseLen = sparseType.getNumElements();
|
||
int64_t denseLen = denseType.getNumElements();
|
||
int64_t destLen = destType.getNumElements();
|
||
|
||
if (denseLen != 2 * sparseLen)
|
||
return emitOpError("expected dense source operand to have exactly double "
|
||
"the number of elements of the sparse source operand");
|
||
|
||
// Check that source element types are compatible.
|
||
// For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8).
|
||
// For other types, element types must match exactly.
|
||
bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
|
||
if (!bothFloat8 && sparseElem != denseElem)
|
||
return emitOpError(
|
||
"expected source operands to have the same element type");
|
||
|
||
// When CBSZ == 0, ABID selects the index set within the sparse index VGPR.
|
||
// When CBSZ != 0, the first index set is always used (ABID ignored).
|
||
bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
|
||
// 8-bit source: ABID selects one of two 16-bit index sets.
|
||
if (getCbsz() == 0 && is8BitSource && getAbid() > 1)
|
||
return emitOpError("ABID must be 0 or 1 for 8-bit source data");
|
||
// 16-bit source: ABID selects one of four 8-bit index sets (0-3 all valid).
|
||
if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
|
||
return emitOpError("ABID must be between 0 and 3 for 16-bit source data");
|
||
|
||
// Validate sparseIdx type matches source element type.
|
||
auto sparseIdxType = cast<VectorType>(getSparseIdx().getType());
|
||
if (is8BitSource) {
|
||
// 8-bit source data requires vector<2xi16> sparse indices.
|
||
if (sparseIdxType.getNumElements() != 2 ||
|
||
!sparseIdxType.getElementType().isInteger(16))
|
||
return emitOpError("expected vector<2xi16> sparse indices for 8-bit "
|
||
"source data, but got ")
|
||
<< getSparseIdx().getType();
|
||
} else {
|
||
// 16-bit source data requires vector<4xi8> sparse indices.
|
||
if (sparseIdxType.getNumElements() != 4 ||
|
||
!sparseIdxType.getElementType().isInteger(8))
|
||
return emitOpError("expected vector<4xi8> sparse indices for 16-bit "
|
||
"source data, but got ")
|
||
<< getSparseIdx().getType();
|
||
}
|
||
|
||
int64_t expectedSourceElems = (getM() * getK()) / waveSize;
|
||
if (denseLen != expectedSourceElems)
|
||
return emitOpError("expected " + Twine(expectedSourceElems) +
|
||
" source values for this operation but got " +
|
||
Twine(denseLen));
|
||
|
||
int64_t expectedDestElems = (getM() * getN()) / waveSize;
|
||
if (destLen != expectedDestElems)
|
||
return emitOpError("expected " + Twine(expectedDestElems) +
|
||
" result values for this operation but got " +
|
||
Twine(destLen));
|
||
|
||
return success();
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// DPPOp
|
||
//===----------------------------------------------------------------------===//
|
||
LogicalResult DPPOp::verify() {
|
||
DPPPerm kind = getKind();
|
||
Attribute permArgument = getPermArgument().value_or(Attribute{});
|
||
|
||
switch (kind) {
|
||
|
||
case DPPPerm::quad_perm: {
|
||
auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
|
||
if (!quadPermAttr || quadPermAttr.size() != 4) {
|
||
return emitOpError("quad_perm attribute must have exactly 4 elements");
|
||
}
|
||
for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
|
||
int32_t num = elem.getInt();
|
||
if (num < 0 || num > 3) {
|
||
return emitOpError(
|
||
"Each element of quad_perm must be in the range [0, 3]");
|
||
}
|
||
}
|
||
} break;
|
||
|
||
case DPPPerm::row_shl:
|
||
case DPPPerm::row_shr:
|
||
case DPPPerm::row_ror: {
|
||
if (!permArgument) {
|
||
return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
|
||
"' value not specified");
|
||
}
|
||
if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
|
||
uint32_t attrValue = intAttr.getInt();
|
||
if (attrValue < 1 || attrValue > 15) {
|
||
return emitOpError("Attribute value must be between 1 and 15");
|
||
}
|
||
}
|
||
} break;
|
||
|
||
case DPPPerm::wave_shl:
|
||
case DPPPerm::wave_shr:
|
||
case DPPPerm::wave_rol:
|
||
case DPPPerm::wave_ror:
|
||
case DPPPerm::row_mirror:
|
||
case DPPPerm::row_half_mirror:
|
||
case DPPPerm::row_bcast_15:
|
||
case DPPPerm::row_bcast_31: {
|
||
if (permArgument && !isa<UnitAttr>(permArgument)) {
|
||
return emitOpError("Expected unit attribute for permArgument, but found "
|
||
"non-trivial argument");
|
||
}
|
||
break;
|
||
}
|
||
}
|
||
return success();
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// PermlaneSwapOp
|
||
//===----------------------------------------------------------------------===//
|
||
LogicalResult PermlaneSwapOp::verify() {
|
||
unsigned rowLength = getRowLength();
|
||
|
||
if (rowLength != 16 && rowLength != 32)
|
||
return emitOpError("row_length attribute must either be 16 or 32.");
|
||
|
||
return success();
|
||
}
|
||
|
||
/// Remove amdgpu.lds_barrier after amdgpu.lds_barrier.
|
||
static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op,
|
||
PatternRewriter &rewriter) {
|
||
if (isa_and_nonnull<LDSBarrierOp>(op->getNextNode())) {
|
||
rewriter.eraseOp(op);
|
||
return success();
|
||
}
|
||
return failure();
|
||
}
|
||
|
||
void LDSBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||
MLIRContext *context) {
|
||
results.add(eraseRedundantLDSBarrierOps);
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// MemoryCounterWaitOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
namespace {
|
||
/// Fuse adjacent memory counter wait ops, taking the minimum value of the
|
||
/// counters.
|
||
struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
|
||
using Base::Base;
|
||
|
||
LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
|
||
if (!next)
|
||
return failure();
|
||
|
||
auto setters = {&MemoryCounterWaitOp::setLoad,
|
||
&MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
|
||
&MemoryCounterWaitOp::setExp,
|
||
&MemoryCounterWaitOp::setTensor};
|
||
auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
|
||
op.getTensor()};
|
||
auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
|
||
next.getExp(), next.getTensor()};
|
||
rewriter.modifyOpInPlace(op, [&] {
|
||
for (auto [setter, lhs, rhs] :
|
||
llvm::zip_equal(setters, lhsVals, rhsVals)) {
|
||
if (lhs && rhs) {
|
||
(op.*setter)(std::min(*lhs, *rhs));
|
||
} else if (lhs) {
|
||
(op.*setter)(*lhs);
|
||
} else if (rhs) {
|
||
(op.*setter)(*rhs);
|
||
}
|
||
}
|
||
});
|
||
rewriter.eraseOp(next);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
void MemoryCounterWaitOp::getCanonicalizationPatterns(
|
||
RewritePatternSet &results, MLIRContext *context) {
|
||
results.add<FuseMemoryCounterWaitOp>(context);
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// GatherToLDSOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
LogicalResult GatherToLDSOp::verify() {
|
||
MemRefType srcType = cast<MemRefType>(getSrc().getType());
|
||
MemRefType dstType = cast<MemRefType>(getDst().getType());
|
||
|
||
if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
|
||
return emitOpError("destination type inner most dim must be contiguous");
|
||
|
||
auto elemType = srcType.getElementType();
|
||
// Check $src and $dst element types are the same.
|
||
if (elemType != dstType.getElementType())
|
||
return emitOpError("source and destination element types must match");
|
||
|
||
// copy type sizes should be 1, 2, 4, 12 or 16 bytes.
|
||
auto transferType = getTransferType();
|
||
int transferSize;
|
||
if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
|
||
transferSize = vectorTransfer.getNumElements() *
|
||
vectorTransfer.getElementTypeBitWidth();
|
||
} else {
|
||
transferSize = transferType.getIntOrFloatBitWidth();
|
||
}
|
||
if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
|
||
return emitOpError(
|
||
"Transfering type size must be 8, 16, 32, 96 or 128 bits");
|
||
|
||
if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
|
||
!hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
|
||
return emitOpError(
|
||
"source memory address space must be global or fat raw buffer");
|
||
|
||
if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
|
||
return emitOpError("destination memory address space must be Workgroup");
|
||
|
||
return success();
|
||
}
|
||
|
||
namespace {
|
||
/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
|
||
/// information or changes layout, the cast can be skipped.
|
||
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
|
||
using OpRewritePattern::OpRewritePattern;
|
||
|
||
LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
|
||
PatternRewriter &rewriter) const override {
|
||
bool modified = false;
|
||
auto foldCast = [&](OpOperand &operand) {
|
||
if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
|
||
if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
|
||
rewriter.modifyOpInPlace(gatherOp,
|
||
[&] { operand.assign(castOp.getSource()); });
|
||
modified = true;
|
||
}
|
||
}
|
||
};
|
||
|
||
foldCast(gatherOp.getSrcMutable());
|
||
foldCast(gatherOp.getDstMutable());
|
||
|
||
return success(modified);
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||
MLIRContext *context) {
|
||
results.add<FoldGatherToLDSOfCast>(context);
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// TransposeLoadOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
LogicalResult TransposeLoadOp::verify() {
|
||
MemRefType srcType = cast<MemRefType>(getSrc().getType());
|
||
|
||
if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
|
||
return emitOpError("source memory address space must be Workgroup");
|
||
|
||
auto transferType = cast<VectorType>(getType());
|
||
size_t numElements = transferType.getNumElements();
|
||
size_t elementTypeSize =
|
||
transferType.getElementType().getIntOrFloatBitWidth();
|
||
|
||
// ElementSize -> NumElements
|
||
const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
|
||
{4, 16},
|
||
{6, 16},
|
||
{8, 8},
|
||
{16, 4},
|
||
};
|
||
|
||
auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
|
||
if (validNumElems == kValidLoadSizeMap.end())
|
||
return emitOpError("Unsupported element type size for transpose load: ")
|
||
<< elementTypeSize << " bits";
|
||
|
||
if (numElements != validNumElems->second)
|
||
return emitOpError(
|
||
"Transferring type size mismatch: expected num of elements: ")
|
||
<< validNumElems->second;
|
||
|
||
return success();
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// MakeDmaBaseOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
template <typename BaseOp>
|
||
static LogicalResult verifyBase(BaseOp op) {
|
||
auto ldsType = cast<MemRefType>(op.getLds().getType());
|
||
auto globalType = cast<MemRefType>(op.getGlobal().getType());
|
||
if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
|
||
return op.emitOpError(
|
||
"lds memref must have workgroup address space attribute.");
|
||
if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
|
||
return op.emitOpError(
|
||
"global memref must have global address space attribute.");
|
||
|
||
Type elementType = ldsType.getElementType();
|
||
unsigned width = elementType.getIntOrFloatBitWidth();
|
||
|
||
if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
|
||
return op.emitOpError(
|
||
"element type must be 1, 2, 4, or 8 bytes long but type was ")
|
||
<< width << " bits long.";
|
||
return success();
|
||
}
|
||
|
||
LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// MakeGatherDmaBaseOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
LogicalResult
|
||
TDMGatherBaseType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||
Type elementType, Type indexType) {
|
||
unsigned width = elementType.getIntOrFloatBitWidth();
|
||
if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
|
||
return emitError()
|
||
<< "element type must be 1, 2, 4, or 8 bytes wide but type "
|
||
<< elementType << " is " << width / 8 << " bytes wide.";
|
||
MLIRContext *ctx = elementType.getContext();
|
||
Type i16 = IntegerType::get(ctx, 32);
|
||
Type i32 = IntegerType::get(ctx, 16);
|
||
if (!llvm::is_contained({i16, i32}, indexType))
|
||
return emitError() << "index type must be i16 or i32 but index type is "
|
||
<< indexType << ".";
|
||
return success();
|
||
}
|
||
|
||
LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// MakeDmaDescriptorOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
template <typename DescriptorOp>
|
||
static LogicalResult verifyDescriptorOp(DescriptorOp op) {
|
||
ArrayRef<int64_t> globalStaticStrides = op.getGlobalStaticStrides();
|
||
|
||
if (globalStaticStrides.empty())
|
||
return op.emitOpError("strides must not be empty.");
|
||
if (globalStaticStrides.back() != 1)
|
||
return op.emitOpError("strides for the innermost dimension must be 1.");
|
||
|
||
ArrayRef<int64_t> globalStaticSizes = op.getGlobalStaticSizes();
|
||
size_t rank = globalStaticSizes.size();
|
||
if (rank > 5)
|
||
return op.emitOpError("tensor and tile must be at most of rank 5.");
|
||
if (rank != globalStaticStrides.size())
|
||
return op.emitOpError("strides and sizes must have same rank.");
|
||
|
||
ArrayRef<int64_t> sharedStaticSizes = op.getSharedStaticSizes();
|
||
if (rank != sharedStaticSizes.size())
|
||
return op.emitOpError("tensor must have same rank as tile.");
|
||
|
||
unsigned elementTypeWidth = op.getElementTypeWidth();
|
||
if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
|
||
return op.emitOpError(
|
||
"element type width must be 1, 2, 4 or 8 bytes, but was ")
|
||
<< elementTypeWidth << " bits long";
|
||
|
||
if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
|
||
auto atomicBarrierAddressType =
|
||
cast<MemRefType>(atomicBarrierAddress.getType());
|
||
bool barrierInLDS =
|
||
hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace());
|
||
if (!barrierInLDS)
|
||
return op.emitOpError("atomic barrier address must be in LDS.");
|
||
}
|
||
|
||
if (op.getEarlyTimeout() && !op.getWorkgroupMask())
|
||
return op.emitOpError(
|
||
"early timeout does not apply when workgroup_mask is not set.");
|
||
return success();
|
||
}
|
||
|
||
template <typename DescriptorOp, typename FoldAdaptor>
|
||
static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) {
|
||
SmallVector<OpFoldResult> mixedGlobalSizes(op.getMixedGlobalSizes());
|
||
SmallVector<OpFoldResult> mixedGlobalStrides(op.getMixedGlobalStrides());
|
||
SmallVector<OpFoldResult> mixedSharedSizes(op.getMixedSharedSizes());
|
||
|
||
if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true,
|
||
/*onlyNonZero=*/true)) &&
|
||
failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true,
|
||
/*onlyNonZero=*/true)) &&
|
||
failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true,
|
||
/*onlyNonZero=*/true)))
|
||
return nullptr;
|
||
|
||
SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides,
|
||
dynamicSharedSizes;
|
||
SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides,
|
||
staticSharedSizes;
|
||
|
||
dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
|
||
staticGlobalSizes);
|
||
op.setGlobalStaticSizes(staticGlobalSizes);
|
||
op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
|
||
|
||
dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
|
||
staticGlobalStrides);
|
||
op.setGlobalStaticStrides(staticGlobalStrides);
|
||
op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
|
||
|
||
dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
|
||
staticSharedSizes);
|
||
op.setSharedStaticSizes(staticSharedSizes);
|
||
op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
|
||
return op.getResult();
|
||
}
|
||
|
||
LogicalResult MakeDmaDescriptorOp::verify() {
|
||
return verifyDescriptorOp(*this);
|
||
}
|
||
|
||
OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
|
||
return foldDescriptorOp(*this, adaptor);
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// MakeGatherDmaDescriptorOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
LogicalResult MakeGatherDmaDescriptorOp::verify() {
|
||
ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
|
||
size_t rank = globalStaticSizes.size();
|
||
if (rank > 2)
|
||
return emitOpError(
|
||
"tensor and tile must be at most of rank two in gather mode.");
|
||
Value indices = getIndices();
|
||
Type elementType = cast<VectorType>(indices.getType()).getElementType();
|
||
if (elementType != getBase().getType().getIndexType())
|
||
return emitOpError("indices' element type must match base's element type.");
|
||
|
||
return verifyDescriptorOp(*this);
|
||
}
|
||
|
||
OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
|
||
return foldDescriptorOp(*this, adaptor);
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// ScaledMFMAOp
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
namespace {
|
||
/// Check if the scales input is used in other scaled mfma's while they exist.
|
||
/// If theyre unused then pack the scales.
|
||
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
|
||
using OpRewritePattern::OpRewritePattern;
|
||
|
||
LogicalResult matchAndRewrite(ScaledMFMAOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
auto setOpsel = [&op](unsigned idx, int64_t val) {
|
||
switch (idx) {
|
||
case 3:
|
||
op.setScalesIdxA(val);
|
||
break;
|
||
case 4:
|
||
op.setScalesIdxB(val);
|
||
break;
|
||
default:
|
||
break;
|
||
}
|
||
};
|
||
|
||
// For every scale operand of this ScaledMFMAOp, if the scale is produced by
|
||
// the extraction of a single scale from some vector, then attempt to
|
||
// extract 4 values from that vector instead.
|
||
//
|
||
// Example: (f8 here means f8E8M0FNU)
|
||
// %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
|
||
// %scale = vector.insert %unit, ... : f8 into vector<4xf8>
|
||
// amdgpu.scaled_mfma(%scale[0] * ...
|
||
//
|
||
// rewrite to:
|
||
//
|
||
// %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
|
||
// %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
|
||
// amdgpu.scaled_mfma(%scale[0-3] * ...
|
||
//
|
||
// This creates duplicate shape_casts for every use but these will be
|
||
// removed in CSE.
|
||
for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
|
||
auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
|
||
if (!insertOp) {
|
||
return rewriter.notifyMatchFailure(op,
|
||
"defining op not a vector.insert");
|
||
}
|
||
// If the extracted value is not a single scalar, then it has been packed.
|
||
if (isa<VectorType>(insertOp.getValueToStore().getType())) {
|
||
return rewriter.notifyMatchFailure(
|
||
op, "scaled mfma operand already packed");
|
||
}
|
||
|
||
auto extractOp =
|
||
insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
|
||
if (!extractOp) {
|
||
return rewriter.notifyMatchFailure(op,
|
||
"defining op not a vector.extract");
|
||
}
|
||
|
||
Value scaleSrc = extractOp.getOperand(0);
|
||
auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
|
||
if (!scaleSrcType) {
|
||
return rewriter.notifyMatchFailure(op, "not a vector type");
|
||
}
|
||
|
||
// We do not handle dynamic dims yet, assume that the input is padded to
|
||
// a static shape now.
|
||
if (!scaleSrcType.hasStaticShape()) {
|
||
return rewriter.notifyMatchFailure(op,
|
||
"dynamic dims not yet supported");
|
||
}
|
||
|
||
int64_t numElements = scaleSrcType.getNumElements();
|
||
if (numElements <= 4) {
|
||
return rewriter.notifyMatchFailure(
|
||
op, "no packing if # of scales less than four");
|
||
}
|
||
|
||
// Find a linearized idx using the size and offsets of the extract op.
|
||
auto extractedPos = llvm::to_vector_of<int64_t>(
|
||
llvm::reverse(extractOp.getStaticPosition()));
|
||
ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
|
||
int64_t scaleSrcRank = scaleSrcType.getRank();
|
||
SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
|
||
for (int64_t i = 1; i < scaleSrcRank; ++i) {
|
||
extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
|
||
}
|
||
int64_t idx = linearize(extractedPos, extractSizes);
|
||
|
||
// All n scales (where n is the total number of scales) must now be
|
||
// extracted in chunks of 4 elements. This is done by dividing the
|
||
// original vector of scales into groups of 4 elements
|
||
// at offsets 0, 4, ..., m (where m = n/4). All extractions of a
|
||
// scale at a particular index are now replaced with an extraction
|
||
// of the entire group of 4 elements to which that index belongs.
|
||
//
|
||
// If the number of scales happens to be indivisible by 4, extract
|
||
// the remaining n - m scales in a chunk of 4 elements starting at
|
||
// offset n - 4.
|
||
int64_t offset = idx - (idx % 4);
|
||
int64_t opsel = idx - offset;
|
||
int64_t size = 4l;
|
||
// Accomdate remaining elements in the case of non-4-divisible vectors.
|
||
if (numElements - offset < size) {
|
||
opsel = size - (numElements - idx);
|
||
offset = numElements - 4l;
|
||
}
|
||
Type scaleSrcElemType = scaleSrcType.getElementType();
|
||
auto newSrcType =
|
||
VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
|
||
Value newScaleSrc =
|
||
vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
|
||
auto extract = vector::ExtractStridedSliceOp::create(
|
||
rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
|
||
ArrayRef{int64_t(1)});
|
||
rewriter.modifyOpInPlace(op, [&] {
|
||
op->setOperand(opIdx, extract);
|
||
setOpsel(opIdx, opsel);
|
||
});
|
||
}
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||
MLIRContext *context) {
|
||
results.add<PackScales>(context);
|
||
}
|
||
|
||
//===----------------------------------------------------------------------===//
|
||
// In-LDS Barrier Operations (gfx1250+)
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
template <typename T>
|
||
static LogicalResult verifyDsBarrierOpCommon(T &op) {
|
||
MemRefType memrefType = llvm::cast<MemRefType>(op.getBase().getType());
|
||
if (!hasWorkgroupMemorySpace(memrefType.getMemorySpace()))
|
||
return op.emitOpError("barrier must be in workgroup (LDS) memory");
|
||
|
||
return success();
|
||
}
|
||
|
||
LogicalResult DsBarrierInitOp::verify() {
|
||
return verifyDsBarrierOpCommon(*this);
|
||
}
|
||
|
||
LogicalResult DsBarrierPollStateOp::verify() {
|
||
return verifyDsBarrierOpCommon(*this);
|
||
}
|
||
|
||
LogicalResult DsAsyncBarrierArriveOp::verify() {
|
||
return verifyDsBarrierOpCommon(*this);
|
||
}
|
||
|
||
LogicalResult DsBarrierArriveOp::verify() {
|
||
return verifyDsBarrierOpCommon(*this);
|
||
}
|
||
|
||
#define GET_OP_CLASSES
|
||
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
|