//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
|
|
|
|
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
|
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
|
|
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
|
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "../LLVMCommon/MemRefDescriptor.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
#include <optional>
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::amdgpu;
|
|
|
|
// Define commonly used chipsets versions for convenience.
|
|
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
|
|
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
|
|
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
|
|
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
|
|
constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
|
|
|
|
/// Convert an unsigned number `val` to i32.
|
|
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
|
|
Location loc, Value val) {
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
// Force check that `val` is of int type.
|
|
auto valTy = cast<IntegerType>(val.getType());
|
|
if (i32 == valTy)
|
|
return val;
|
|
return valTy.getWidth() > 32
|
|
? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
|
|
: Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
|
|
}
|
|
|
|
/// Create an `LLVM::ConstantOp` of type i32 holding `value` and return its
/// result.
static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  IntegerType i32 = rewriter.getI32Type();
  return LLVM::ConstantOp::create(rewriter, loc, i32, value);
}
|
|
|
|
/// Convert an unsigned number `val` to i64.
|
|
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
|
|
Location loc, Value val) {
|
|
IntegerType i64 = rewriter.getI64Type();
|
|
// Force check that `val` is of int type.
|
|
auto valTy = cast<IntegerType>(val.getType());
|
|
if (i64 == valTy)
|
|
return val;
|
|
return valTy.getWidth() > 64
|
|
? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
|
|
: Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
|
|
}
|
|
|
|
/// Create an `LLVM::ConstantOp` of type i64 holding `value` and return its
/// result.
static Value createI64Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int64_t value) {
  IntegerType i64 = rewriter.getI64Type();
  return LLVM::ConstantOp::create(rewriter, loc, i64, value);
}
|
|
|
|
/// Returns the linear index used to access an element in the memref.
|
|
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
|
|
Location loc, MemRefDescriptor &memRefDescriptor,
|
|
ValueRange indices, ArrayRef<int64_t> strides) {
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
Value index;
|
|
for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
|
|
if (stride != 1) { // Skip if stride is 1.
|
|
Value strideValue =
|
|
ShapedType::isDynamic(stride)
|
|
? convertUnsignedToI32(rewriter, loc,
|
|
memRefDescriptor.stride(rewriter, loc, i))
|
|
: LLVM::ConstantOp::create(rewriter, loc, i32, stride);
|
|
increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
|
|
}
|
|
index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
|
|
: increment;
|
|
}
|
|
return index ? index : createI32Constant(rewriter, loc, 0);
|
|
}
|
|
|
|
/// Compute the contents of the `num_records` field for a given memref
|
|
/// descriptor - that is, the number of bytes that's one element past the
|
|
/// greatest possible valid index into the memref.
|
|
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
|
|
MemRefType memrefType,
|
|
MemRefDescriptor &memrefDescriptor,
|
|
ArrayRef<int64_t> strides,
|
|
int64_t elementByteWidth) {
|
|
if (memrefType.hasStaticShape() &&
|
|
!llvm::any_of(strides, ShapedType::isDynamic)) {
|
|
int64_t size = memrefType.getRank() == 0 ? 1 : 0;
|
|
ArrayRef<int64_t> shape = memrefType.getShape();
|
|
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
|
|
size = std::max(shape[i] * strides[i], size);
|
|
size = size * elementByteWidth;
|
|
return createI64Constant(rewriter, loc, size);
|
|
}
|
|
Value maxIndex;
|
|
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
|
|
Value size = memrefDescriptor.size(rewriter, loc, i);
|
|
Value stride = memrefDescriptor.stride(rewriter, loc, i);
|
|
Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
|
|
maxIndex = maxIndex
|
|
? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
|
|
: maxThisDim;
|
|
}
|
|
Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
|
|
Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
|
|
return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
|
|
}
|
|
|
|
/// Build a buffer resource (`ROCDL::MakeBufferRsrcOp`, possibly folded) in
/// address space `addressSpace` from `basePointer` and `numRecords`.
///
/// `boundsCheck` selects the out-of-bounds mode on chipsets with major version
/// >= 10. `cacheSwizzleStride`, when provided, is only honored on major
/// version 9 chipsets that are at least gfx942; otherwise the stride field is 0.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, cache
  // swizzling can be enabled by setting bit 14 of the stride field and
  // storing the cache stride in the remaining bits.
  Type i16 = rewriter.getI16Type();
  const bool useCacheSwizzle = chipset.majorVersion == 9 &&
                               chipset >= kGfx942 && cacheSwizzleStride;
  Value stride;
  if (useCacheSwizzle) {
    Value swizzleStride =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleEnable = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, swizzleStride, swizzleEnable,
                                /*isDisjoint=*/true);
  } else {
    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));
  }

  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: data format (ignored, must be nonzero, 7=float)
  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  //  none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    const uint32_t outOfBoundsSelect = boundsCheck ? 3 : 2;
    flags |= (1 << 24);
    flags |= (outOfBoundsSelect << 28);
  }
  Value flagsValue = createI32Constant(rewriter, loc, flags);

  Type resourceType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  return rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, resourceType, basePointer, stride, numRecords, flagsValue);
}
|
|
|
|
namespace {
|
|
struct FatRawBufferCastLowering
|
|
: public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
|
|
FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
|
: ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
|
|
chipset(chipset) {}
|
|
|
|
Chipset chipset;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
Value memRef = adaptor.getSource();
|
|
Value unconvertedMemref = op.getSource();
|
|
MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
|
|
MemRefDescriptor descriptor(memRef);
|
|
|
|
DataLayout dataLayout = DataLayout::closest(op);
|
|
int64_t elementByteWidth =
|
|
dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
|
|
|
|
int64_t unusedOffset = 0;
|
|
SmallVector<int64_t, 5> strideVals;
|
|
if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
|
|
return op.emitOpError("Can't lower non-stride-offset memrefs");
|
|
|
|
Value numRecords = adaptor.getValidBytes();
|
|
if (!numRecords)
|
|
numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
|
|
strideVals, elementByteWidth);
|
|
|
|
Value basePointer =
|
|
adaptor.getResetOffset()
|
|
? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
|
|
memrefType)
|
|
: descriptor.alignedPtr(rewriter, loc);
|
|
|
|
Value offset = adaptor.getResetOffset()
|
|
? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
|
|
rewriter.getIndexAttr(0))
|
|
: descriptor.offset(rewriter, loc);
|
|
|
|
bool hasSizes = memrefType.getRank() > 0;
|
|
// No need to unpack() and pack() all the individual sizes and strides,
|
|
// so we'll just extract the arrays.
|
|
Value sizes = hasSizes
|
|
? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
|
|
kSizePosInMemRefDescriptor)
|
|
: Value{};
|
|
Value strides =
|
|
hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
|
|
kStridePosInMemRefDescriptor)
|
|
: Value{};
|
|
|
|
Value fatPtr = makeBufferRsrc(
|
|
rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
|
|
chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
|
|
|
|
Value result = MemRefDescriptor::poison(
|
|
rewriter, loc,
|
|
getTypeConverter()->convertType(op.getResult().getType()));
|
|
SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
|
|
result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
|
|
result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
|
|
kAlignedPtrPosInMemRefDescriptor);
|
|
result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
|
|
kOffsetPosInMemRefDescriptor);
|
|
if (hasSizes) {
|
|
result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
|
|
kSizePosInMemRefDescriptor);
|
|
result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
|
|
kStridePosInMemRefDescriptor);
|
|
}
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Define lowering patterns for raw buffer ops
|
|
template <typename GpuOp, typename Intrinsic>
|
|
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
|
|
RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
|
: ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
|
|
|
|
Chipset chipset;
|
|
static constexpr uint32_t maxVectorOpWidth = 128;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Location loc = gpuOp.getLoc();
|
|
Value memref = adaptor.getMemref();
|
|
Value unconvertedMemref = gpuOp.getMemref();
|
|
MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
|
|
|
|
if (chipset.majorVersion < 9)
|
|
return gpuOp.emitOpError("raw buffer ops require GCN or higher");
|
|
|
|
Value storeData = adaptor.getODSOperands(0)[0];
|
|
if (storeData == memref) // no write component to this op
|
|
storeData = Value();
|
|
Type wantedDataType;
|
|
if (storeData)
|
|
wantedDataType = storeData.getType();
|
|
else
|
|
wantedDataType = gpuOp.getODSResults(0)[0].getType();
|
|
|
|
Value atomicCmpData = Value();
|
|
// Operand index 1 of a load is the indices, trying to read them can crash.
|
|
if (storeData) {
|
|
Value maybeCmpData = adaptor.getODSOperands(1)[0];
|
|
if (maybeCmpData != memref)
|
|
atomicCmpData = maybeCmpData;
|
|
}
|
|
|
|
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
|
|
|
|
Type i32 = rewriter.getI32Type();
|
|
|
|
// Get the type size in bytes.
|
|
DataLayout dataLayout = DataLayout::closest(gpuOp);
|
|
int64_t elementByteWidth =
|
|
dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
|
|
Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
|
|
|
|
// If we want to load a vector<NxT> with total size <= 32
|
|
// bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
|
|
// and the total load size is >= 32, use a vector load of N / (bitsize(T) /
|
|
// 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
|
|
// so bitcast any floats to integers.
|
|
Type llvmBufferValType = llvmWantedDataType;
|
|
if (atomicCmpData) {
|
|
if (auto floatType = dyn_cast<FloatType>(wantedDataType))
|
|
llvmBufferValType = this->getTypeConverter()->convertType(
|
|
rewriter.getIntegerType(floatType.getWidth()));
|
|
}
|
|
if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
|
|
uint32_t vecLen = dataVector.getNumElements();
|
|
uint32_t elemBits =
|
|
dataLayout.getTypeSizeInBits(dataVector.getElementType());
|
|
uint32_t totalBits = elemBits * vecLen;
|
|
bool usePackedFp16 =
|
|
isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
|
|
if (totalBits > maxVectorOpWidth)
|
|
return gpuOp.emitOpError(
|
|
"Total width of loads or stores must be no more than " +
|
|
Twine(maxVectorOpWidth) + " bits, but we call for " +
|
|
Twine(totalBits) +
|
|
" bits. This should've been caught in validation");
|
|
if (!usePackedFp16 && elemBits < 32) {
|
|
if (totalBits > 32) {
|
|
if (totalBits % 32 != 0)
|
|
return gpuOp.emitOpError("Load or store of more than 32-bits that "
|
|
"doesn't fit into words. Can't happen\n");
|
|
llvmBufferValType = this->typeConverter->convertType(
|
|
VectorType::get(totalBits / 32, i32));
|
|
} else {
|
|
llvmBufferValType = this->typeConverter->convertType(
|
|
rewriter.getIntegerType(totalBits));
|
|
}
|
|
}
|
|
}
|
|
if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
|
|
// Buffer intrinsics doesn't support 1-element vectors, cast them to
|
|
// scalars.
|
|
if (vecType.getNumElements() == 1)
|
|
llvmBufferValType = vecType.getElementType();
|
|
}
|
|
|
|
SmallVector<Value, 6> args;
|
|
if (storeData) {
|
|
if (llvmBufferValType != llvmWantedDataType) {
|
|
Value castForStore = LLVM::BitcastOp::create(
|
|
rewriter, loc, llvmBufferValType, storeData);
|
|
args.push_back(castForStore);
|
|
} else {
|
|
args.push_back(storeData);
|
|
}
|
|
}
|
|
|
|
if (atomicCmpData) {
|
|
if (llvmBufferValType != llvmWantedDataType) {
|
|
Value castForCmp = LLVM::BitcastOp::create(
|
|
rewriter, loc, llvmBufferValType, atomicCmpData);
|
|
args.push_back(castForCmp);
|
|
} else {
|
|
args.push_back(atomicCmpData);
|
|
}
|
|
}
|
|
|
|
// Construct buffer descriptor from memref, attributes
|
|
int64_t offset = 0;
|
|
SmallVector<int64_t, 5> strides;
|
|
if (failed(memrefType.getStridesAndOffset(strides, offset)))
|
|
return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
|
|
|
|
MemRefDescriptor memrefDescriptor(memref);
|
|
|
|
Value ptr = memrefDescriptor.bufferPtr(
|
|
rewriter, loc, *this->getTypeConverter(), memrefType);
|
|
Value numRecords = getNumRecords(
|
|
rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
|
|
Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
|
|
adaptor.getBoundsCheck(), chipset);
|
|
args.push_back(resource);
|
|
|
|
// Indexing (voffset)
|
|
Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
|
|
adaptor.getIndices(), strides);
|
|
if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
|
|
indexOffset && *indexOffset > 0) {
|
|
Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
|
|
voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
|
|
extraOffsetConst)
|
|
: extraOffsetConst;
|
|
}
|
|
voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
|
|
args.push_back(voffset);
|
|
|
|
// SGPR offset.
|
|
Value sgprOffset = adaptor.getSgprOffset();
|
|
if (!sgprOffset)
|
|
sgprOffset = createI32Constant(rewriter, loc, 0);
|
|
sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
|
|
args.push_back(sgprOffset);
|
|
|
|
// bit 0: GLC = 0 (atomics drop value, less coherency)
|
|
// bits 1-2: SLC, DLC = 0 (similarly)
|
|
// bit 3: swizzled (0 for raw)
|
|
args.push_back(createI32Constant(rewriter, loc, 0));
|
|
|
|
llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
|
|
llvmBufferValType);
|
|
Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
|
|
ArrayRef<NamedAttribute>());
|
|
if (lowered->getNumResults() == 1) {
|
|
Value replacement = lowered->getResult(0);
|
|
if (llvmBufferValType != llvmWantedDataType) {
|
|
replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
|
|
replacement);
|
|
}
|
|
rewriter.replaceOp(gpuOp, replacement);
|
|
} else {
|
|
rewriter.eraseOp(gpuOp);
|
|
}
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// TODO: AMDGPU backend already have all this bitpacking logic, we should move
|
|
// it to some common place.
|
|
/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
|
|
/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
|
|
/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
|
|
/// Vmcnt = Waitcnt[15:10] (gfx11)
|
|
/// Expcnt = Waitcnt[6:4] (pre-gfx11)
|
|
/// Expcnt = Waitcnt[2:0] (gfx11)
|
|
/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
|
|
/// Lgkmcnt = Waitcnt[13:8] (gfx10)
|
|
/// Lgkmcnt = Waitcnt[9:4] (gfx11)
|
|
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
|
|
unsigned expcnt, unsigned lgkmcnt) {
|
|
if (chipset.majorVersion < 9) {
|
|
vmcnt = std::min(15u, vmcnt);
|
|
expcnt = std::min(7u, expcnt);
|
|
lgkmcnt = std::min(15u, lgkmcnt);
|
|
return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
|
|
}
|
|
if (chipset.majorVersion == 9) {
|
|
vmcnt = std::min(63u, vmcnt);
|
|
expcnt = std::min(7u, expcnt);
|
|
lgkmcnt = std::min(15u, lgkmcnt);
|
|
unsigned lowBits = vmcnt & 0xF;
|
|
unsigned highBits = (vmcnt >> 4) << 14;
|
|
unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
|
|
return lowBits | highBits | otherCnts;
|
|
}
|
|
if (chipset.majorVersion == 10) {
|
|
vmcnt = std::min(63u, vmcnt);
|
|
expcnt = std::min(7u, expcnt);
|
|
lgkmcnt = std::min(63u, lgkmcnt);
|
|
unsigned lowBits = vmcnt & 0xF;
|
|
unsigned highBits = (vmcnt >> 4) << 14;
|
|
unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
|
|
return lowBits | highBits | otherCnts;
|
|
}
|
|
if (chipset.majorVersion == 11) {
|
|
vmcnt = std::min(63u, vmcnt);
|
|
expcnt = std::min(7u, expcnt);
|
|
lgkmcnt = std::min(63u, lgkmcnt);
|
|
return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
struct MemoryCounterWaitOpLowering
|
|
: public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
|
|
MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
|
|
Chipset chipset)
|
|
: ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
|
|
chipset(chipset) {}
|
|
|
|
Chipset chipset;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (chipset.majorVersion >= 12) {
|
|
Location loc = op.getLoc();
|
|
if (std::optional<int> ds = adaptor.getDs())
|
|
ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
|
|
|
|
if (std::optional<int> load = adaptor.getLoad())
|
|
ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
|
|
|
|
if (std::optional<int> store = adaptor.getStore())
|
|
ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
|
|
|
|
if (std::optional<int> exp = adaptor.getExp())
|
|
ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
|
|
|
|
if (std::optional<int> tensor = adaptor.getTensor())
|
|
ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
|
|
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
|
|
if (adaptor.getTensor())
|
|
return op.emitOpError("unsupported chipset");
|
|
|
|
auto getVal = [](Attribute attr) -> unsigned {
|
|
if (attr)
|
|
return cast<IntegerAttr>(attr).getInt();
|
|
|
|
// This value will be clamped to the maximum value for the chipset.
|
|
return 1024;
|
|
};
|
|
unsigned ds = getVal(adaptor.getDsAttr());
|
|
unsigned exp = getVal(adaptor.getExpAttr());
|
|
|
|
unsigned vmcnt = 1024;
|
|
Attribute load = adaptor.getLoadAttr();
|
|
Attribute store = adaptor.getStoreAttr();
|
|
if (load && store) {
|
|
vmcnt = getVal(load) + getVal(store);
|
|
} else if (load) {
|
|
vmcnt = getVal(load);
|
|
} else if (store) {
|
|
vmcnt = getVal(store);
|
|
}
|
|
|
|
FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
|
|
if (failed(waitcnt))
|
|
return op.emitOpError("unsupported chipset");
|
|
|
|
rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lowers `LDSBarrierOp` to a release fence, a barrier, and an acquire fence.
/// Both fences use workgroup scope and carry an MMRA tag restricting them to
/// the "local" address space. The barrier itself is inline assembly on
/// chipsets older than gfx90a, `ROCDL::SBarrierOp` up to major version 11, and
/// a signal/wait pair from major version 12 on.
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  /// Target chipset; selects which barrier instruction form is emitted.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // This ensures that waits on global memory aren't introduced on
    // chips that don't have the BackOffBarrier feature enabled in LLVM.
    const bool requiresInlineAsm = chipset < kGfx90a;

    Attribute mmra =
        rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
    // Note: while there *is* a workgroup-one-as scope, this, when combined with
    // the MMRA, will lead to the fence having no effect. This is because the
    // codepaths for an atomic load or store will observe that a
    // one-address-space atomic to LDS requires no synchronization because
    // operations on LDS are totally ordered with respect to each other, and so
    // will not emit the correct waitcnt operations that these fences are
    // intended to produce. Therefore, we use a broader type of fence and rely
    // on the MMRA to relax it to the semantics we want.
    StringRef scope = "workgroup";

    // Creates a fence with the given ordering and attaches the MMRA tag.
    auto emitTaggedFence = [&](LLVM::AtomicOrdering ordering) {
      auto fence = LLVM::FenceOp::create(rewriter, loc, ordering, scope);
      fence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
      return fence;
    };

    emitTaggedFence(LLVM::AtomicOrdering::release);

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
      const char *constraints = "";
      LLVM::InlineAsmOp::create(
          rewriter, loc,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
    } else if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }

    auto acquireFence = emitTaggedFence(LLVM::AtomicOrdering::acquire);
    rewriter.replaceOp(op, acquireFence);
    return success();
  }
};
|
|
|
|
/// Lowers `SchedBarrierOp` to `ROCDL::SchedBarrier`, passing the op's option
/// mask through as a 32-bit unsigned integer.
struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  /// Target chipset; stored but not consulted by this lowering.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto optionMask = static_cast<uint32_t>(op.getOpts());
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op, optionMask);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
|
|
/// and LLVM AMDGPU intrinsics convention.
|
|
///
|
|
/// Specifically:
|
|
/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
|
|
/// allows bf16. Newer MFMAs support bf16 types on operand, check
|
|
/// IntrinsicsAMDGPU.td file for reference.
|
|
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
|
|
/// instead, which is what the f8f6f4 intrinsics use.
|
|
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
|
|
/// integer.
|
|
///
|
|
/// Note that the type of `input` has already been LLVM type converted:
|
|
/// therefore 8-bit and smaller floats are represented as their corresponding
|
|
/// `iN` integers.
|
|
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
|
|
Location loc, Value input,
|
|
bool allowBf16 = true) {
|
|
Type inputType = input.getType();
|
|
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
|
|
if (vectorType.getElementType().isBF16() && !allowBf16)
|
|
return LLVM::BitcastOp::create(
|
|
rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
|
|
if (vectorType.getElementType().isInteger(8) &&
|
|
vectorType.getNumElements() <= 8)
|
|
return LLVM::BitcastOp::create(
|
|
rewriter, loc,
|
|
rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
|
|
if (isa<IntegerType>(vectorType.getElementType()) &&
|
|
vectorType.getElementTypeBitWidth() <= 8) {
|
|
int64_t numWords = llvm::divideCeil(
|
|
vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
|
|
32);
|
|
return LLVM::BitcastOp::create(
|
|
rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
|
|
input);
|
|
}
|
|
}
|
|
return input;
|
|
}
|
|
|
|
/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
|
|
/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
|
|
///
|
|
/// Specifically:
|
|
/// 1. If `input` is a i8 value, zero extend it to i32
|
|
/// 2. If `input` is a vector of length 4 and type i8, cast it to i32
|
|
///
|
|
/// Note that the type of `input` has already been LLVM type converted:
|
|
/// therefore 8-bit and smaller floats are represented as their corresponding
|
|
/// `iN` integers.
|
|
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
|
|
Location loc, Value input) {
|
|
Type inputType = input.getType();
|
|
Type outputType = rewriter.getI32Type();
|
|
if (auto intType = dyn_cast<IntegerType>(inputType))
|
|
return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
|
|
return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
|
|
}
|
|
|
|
/// Push an input operand. If it is a float type, nothing to do. If it is
|
|
/// an integer type, then we need to also push its signdness (1 for signed, 0
|
|
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
|
|
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
|
|
/// We also need to convert bfloat inputs to i16 to account for the bfloat
|
|
/// intrinsics having been defined before the AMD backend supported bfloat. We
|
|
/// similarly need to pack 8-bit float types into integers as if they were i8
|
|
/// (which they are for the backend's purposes).
|
|
static void wmmaPushInputOperand(
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
|
|
Value mlirInput, SmallVectorImpl<Value> &operands,
|
|
SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
|
|
Type inputType = llvmInput.getType();
|
|
auto vectorType = dyn_cast<VectorType>(inputType);
|
|
if (!vectorType) {
|
|
operands.push_back(llvmInput);
|
|
return;
|
|
}
|
|
Type elemType = vectorType.getElementType();
|
|
if (elemType.getIntOrFloatBitWidth() > 8) {
|
|
operands.push_back(llvmInput);
|
|
return;
|
|
}
|
|
|
|
// We need to check the type of the input before conversion to properly test
|
|
// for int8. This is because, in LLVM, fp8 type is converted to int8, so the
|
|
// fp8/int8 information is lost during the conversion process.
|
|
auto mlirInputType = cast<VectorType>(mlirInput.getType());
|
|
bool isInputInteger = mlirInputType.getElementType().isInteger();
|
|
if (isInputInteger) {
|
|
// if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
|
|
bool localIsUnsigned = isUnsigned;
|
|
if (elemType.isUnsignedInteger()) {
|
|
localIsUnsigned = true;
|
|
} else if (elemType.isSignedInteger()) {
|
|
localIsUnsigned = false;
|
|
}
|
|
attrs.push_back(
|
|
NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
|
|
}
|
|
|
|
int64_t numBits =
|
|
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
|
|
Type i32 = rewriter.getI32Type();
|
|
Type intrinsicInType = numBits <= 32
|
|
? (Type)rewriter.getIntegerType(numBits)
|
|
: (Type)VectorType::get(numBits / 32, i32);
|
|
auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
|
|
Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
|
|
loc, llvmIntrinsicInType, llvmInput);
|
|
// The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
|
|
// (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
|
|
// Add in the zeros here.
|
|
if (numBits < 32)
|
|
castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
|
|
operands.push_back(castInput);
|
|
}
|
|
|
|
/// Push the output operand. For many cases this is only pushing the output in
|
|
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
|
|
/// since the same numbers of VGPRs is used, we need to decide if to store the
|
|
/// result in the upper 16 bits of the VGPRs or in the lower part. To store the
|
|
/// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
|
|
/// be stored it in the upper part. The subwordOffset must not be set for gfx12,
|
|
/// as the instructions have been changed to return fewer registers instead.
|
|
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
|
|
Location loc,
|
|
const TypeConverter *typeConverter,
|
|
Value output, int32_t subwordOffset,
|
|
bool clamp, SmallVectorImpl<Value> &operands,
|
|
SmallVectorImpl<NamedAttribute> &attrs) {
|
|
Type inputType = output.getType();
|
|
auto vectorType = dyn_cast<VectorType>(inputType);
|
|
Type elemType = vectorType.getElementType();
|
|
operands.push_back(output);
|
|
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
|
|
attrs.push_back(
|
|
NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
|
|
} else if (elemType.isInteger(32)) {
|
|
attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
|
|
}
|
|
}
|
|
|
|
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
|
|
/// supported by the `_bf8` instructions on the given `chipset`.
|
|
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
|
|
return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
|
|
(hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
|
|
}
|
|
|
|
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
|
|
/// supported by the `_fp8` instructions on the given `chipset`.
|
|
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
|
|
return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
|
|
(hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
|
|
}
|
|
|
|
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
|
|
/// if one exists. This includes checking to ensure the intrinsic is supported
|
|
/// on the architecture you are compiling for.
|
|
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
|
|
Chipset chipset) {
|
|
uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
|
|
b = mfma.getBlocks();
|
|
Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
|
|
Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
|
|
|
|
if (sourceElem.isF32() && destElem.isF32()) {
|
|
if (mfma.getReducePrecision() && chipset >= kGfx942) {
|
|
if (m == 32 && n == 32 && k == 4 && b == 1)
|
|
return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
|
|
if (m == 16 && n == 16 && k == 8 && b == 1)
|
|
return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
|
|
}
|
|
if (m == 32 && n == 32 && k == 1 && b == 2)
|
|
return ROCDL::mfma_f32_32x32x1f32::getOperationName();
|
|
if (m == 16 && n == 16 && k == 1 && b == 4)
|
|
return ROCDL::mfma_f32_16x16x1f32::getOperationName();
|
|
if (m == 4 && n == 4 && k == 1 && b == 16)
|
|
return ROCDL::mfma_f32_4x4x1f32::getOperationName();
|
|
if (m == 32 && n == 32 && k == 2 && b == 1)
|
|
return ROCDL::mfma_f32_32x32x2f32::getOperationName();
|
|
if (m == 16 && n == 16 && k == 4 && b == 1)
|
|
return ROCDL::mfma_f32_16x16x4f32::getOperationName();
|
|
}
|
|
|
|
if (sourceElem.isF16() && destElem.isF32()) {
|
|
if (chipset >= kGfx950) {
|
|
if (m == 32 && n == 32 && k == 16 && b == 1)
|
|
return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
|
|
if (m == 16 && n == 16 && k == 32 && b == 1)
|
|
return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
|
|
}
|
|
if (m == 32 && n == 32 && k == 4 && b == 2)
|
|
return ROCDL::mfma_f32_32x32x4f16::getOperationName();
|
|
if (m == 16 && n == 16 && k == 4 && b == 4)
|
|
return ROCDL::mfma_f32_16x16x4f16::getOperationName();
|
|
if (m == 4 && n == 4 && k == 4 && b == 16)
|
|
return ROCDL::mfma_f32_4x4x4f16::getOperationName();
|
|
if (m == 32 && n == 32 && k == 8 && b == 1)
|
|
return ROCDL::mfma_f32_32x32x8f16::getOperationName();
|
|
if (m == 16 && n == 16 && k == 16 && b == 1)
|
|
return ROCDL::mfma_f32_16x16x16f16::getOperationName();
|
|
}
|
|
|
|
if (sourceElem.isBF16() && destElem.isF32()) {
|
|
if (chipset >= kGfx950) {
|
|
if (m == 32 && n == 32 && k == 16 && b == 1)
|
|
return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
|
|
if (m == 16 && n == 16 && k == 32 && b == 1)
|
|
return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
|
|
}
|
|
if (chipset >= kGfx90a) {
|
|
if (m == 32 && n == 32 && k == 4 && b == 2)
|
|
return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
|
|
if (m == 16 && n == 16 && k == 4 && b == 4)
|
|
return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
|
|
if (m == 4 && n == 4 && k == 4 && b == 16)
|
|
return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
|
|
if (m == 32 && n == 32 && k == 8 && b == 1)
|
|
return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
|
|
if (m == 16 && n == 16 && k == 16 && b == 1)
|
|
return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
|
|
}
|
|
if (m == 32 && n == 32 && k == 2 && b == 2)
|
|
return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
|
|
if (m == 16 && n == 16 && k == 2 && b == 4)
|
|
return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
|
|
if (m == 4 && n == 4 && k == 2 && b == 16)
|
|
return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
|
|
if (m == 32 && n == 32 && k == 4 && b == 1)
|
|
return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
|
|
if (m == 16 && n == 16 && k == 8 && b == 1)
|
|
return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
|
|
}
|
|
|
|
if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
|
|
if (chipset >= kGfx950) {
|
|
if (m == 32 && n == 32 && k == 32 && b == 1)
|
|
return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
|
|
if (m == 16 && n == 16 && k == 64 && b == 1)
|
|
return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
|
|
}
|
|
if (m == 32 && n == 32 && k == 4 && b == 2)
|
|
return ROCDL::mfma_i32_32x32x4i8::getOperationName();
|
|
if (m == 16 && n == 16 && k == 4 && b == 4)
|
|
return ROCDL::mfma_i32_16x16x4i8::getOperationName();
|
|
if (m == 4 && n == 4 && k == 4 && b == 16)
|
|
return ROCDL::mfma_i32_4x4x4i8::getOperationName();
|
|
if (m == 32 && n == 32 && k == 8 && b == 1)
|
|
return ROCDL::mfma_i32_32x32x8i8::getOperationName();
|
|
if (m == 16 && n == 16 && k == 16 && b == 1)
|
|
return ROCDL::mfma_i32_16x16x16i8::getOperationName();
|
|
if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
|
|
return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
|
|
if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
|
|
return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
|
|
}
|
|
|
|
if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
|
|
if (m == 16 && n == 16 && k == 4 && b == 1)
|
|
return ROCDL::mfma_f64_16x16x4f64::getOperationName();
|
|
if (m == 4 && n == 4 && k == 4 && b == 4)
|
|
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
|
|
}
|
|
|
|
if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
|
|
// Known to be correct because there are no scalar f8 instructions and
|
|
// because a length mismatch will have been caught by the verifier.
|
|
Type sourceBElem =
|
|
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
|
|
if (m == 16 && n == 16 && k == 32 && b == 1) {
|
|
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
|
|
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
|
|
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
|
|
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
|
|
}
|
|
if (m == 32 && n == 32 && k == 16 && b == 1) {
|
|
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
|
|
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
|
|
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
|
|
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
|
|
}
|
|
}
|
|
|
|
if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
|
|
Type sourceBElem =
|
|
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
|
|
if (m == 16 && n == 16 && k == 32 && b == 1) {
|
|
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
|
|
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
|
|
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
|
|
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
|
|
}
|
|
if (m == 32 && n == 32 && k == 16 && b == 1) {
|
|
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
|
|
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
|
|
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
|
|
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
|
|
}
|
|
}
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
/// Map a small-float element type to its numeric type-select code:
/// E4M3FN -> 0, E5M2 -> 1, E2M3FN -> 2, E3M2FN -> 3, E2M1FN -> 4.
/// Any other type yields `std::nullopt`.
static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  if (isa<Float8E4M3FNType>(mlirElemType))
    return 0u;
  if (isa<Float8E5M2Type>(mlirElemType))
    return 1u;
  if (isa<Float6E2M3FNType>(mlirElemType))
    return 2u;
  if (isa<Float6E3M2FNType>(mlirElemType))
    return 3u;
  if (isa<Float4E2M1FNType>(mlirElemType))
    return 4u;
  return std::nullopt;
}
|
|
|
|
/// If there is a scaled MFMA instruction for the input element types `aType`
|
|
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
|
|
/// blocks) on the given `chipset`, return a tuple consisting of the
|
|
/// OperationName of the intrinsic and the type codes that need to be passed to
|
|
/// that intrinsic. Note that this is also used to implement some un-scaled
|
|
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
|
|
/// MFMA with a scale of 0.
|
|
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
|
|
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
|
|
uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
|
|
aType = getElementTypeOrSelf(aType);
|
|
bType = getElementTypeOrSelf(bType);
|
|
destType = getElementTypeOrSelf(destType);
|
|
|
|
if (chipset < kGfx950)
|
|
return std::nullopt;
|
|
if (!isa<Float32Type>(destType))
|
|
return std::nullopt;
|
|
|
|
std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
|
|
std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
|
|
if (!aTypeCode || !bTypeCode)
|
|
return std::nullopt;
|
|
|
|
if (m == 32 && n == 32 && k == 64 && b == 1)
|
|
return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
|
|
*aTypeCode, *bTypeCode};
|
|
if (m == 16 && n == 16 && k == 128 && b == 1)
|
|
return std::tuple{
|
|
ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
|
|
*bTypeCode};
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
|
|
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
|
|
return mfmaOpToScaledIntrinsic(
|
|
mfma.getSourceA().getType(), mfma.getSourceB().getType(),
|
|
mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
|
|
mfma.getBlocks(), chipset);
|
|
}
|
|
|
|
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
|
|
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
|
|
return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
|
|
smfma.getSourceB().getType(),
|
|
smfma.getDestC().getType(), smfma.getM(),
|
|
smfma.getN(), smfma.getK(), 1u, chipset);
|
|
}
|
|
|
|
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
|
|
/// for RDNA3/4 architectures.
|
|
static std::optional<StringRef>
|
|
wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
|
|
Type elemDestType, uint32_t k, bool isRDNA3) {
|
|
using fp8 = Float8E4M3FNType;
|
|
using bf8 = Float8E5M2Type;
|
|
|
|
// Handle k == 16 for RDNA3/4.
|
|
if (k == 16) {
|
|
// Common patterns for RDNA3 and RDNA4.
|
|
if (elemSourceType.isF16() && elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
|
|
if (elemSourceType.isBF16() && elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
|
|
if (elemSourceType.isF16() && elemDestType.isF16())
|
|
return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
|
|
if (elemSourceType.isBF16() && elemDestType.isBF16())
|
|
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
|
|
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
|
|
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
|
|
|
|
// RDNA3 specific patterns.
|
|
if (isRDNA3) {
|
|
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
|
|
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
|
|
return std::nullopt;
|
|
}
|
|
|
|
// RDNA4 specific patterns (fp8/bf8).
|
|
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
|
|
elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
|
|
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
|
|
elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
|
|
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
|
|
elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
|
|
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
|
|
elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
|
|
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
|
|
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
// Handle k == 32 for RDNA4.
|
|
if (k == 32 && !isRDNA3) {
|
|
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
|
|
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
|
|
}
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
|
|
/// for the gfx1250 architecture.
|
|
static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
|
|
Type elemBSourceType,
|
|
Type elemDestType,
|
|
uint32_t k) {
|
|
using fp8 = Float8E4M3FNType;
|
|
using bf8 = Float8E5M2Type;
|
|
|
|
if (k == 4) {
|
|
if (elemSourceType.isF32() && elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
if (k == 32) {
|
|
if (elemSourceType.isF16() && elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
|
|
if (elemSourceType.isBF16() && elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
|
|
if (elemSourceType.isF16() && elemDestType.isF16())
|
|
return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
|
|
if (elemSourceType.isBF16() && elemDestType.isBF16())
|
|
return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
if (k == 64) {
|
|
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
|
|
if (elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
|
|
if (elemDestType.isF16())
|
|
return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
|
|
}
|
|
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
|
|
if (elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
|
|
if (elemDestType.isF16())
|
|
return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
|
|
}
|
|
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
|
|
if (elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
|
|
if (elemDestType.isF16())
|
|
return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
|
|
}
|
|
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
|
|
if (elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
|
|
if (elemDestType.isF16())
|
|
return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
|
|
}
|
|
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
|
|
return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
if (k == 128) {
|
|
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
|
|
if (elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
|
|
if (elemDestType.isF16())
|
|
return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
|
|
}
|
|
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
|
|
if (elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
|
|
if (elemDestType.isF16())
|
|
return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
|
|
}
|
|
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
|
|
if (elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
|
|
if (elemDestType.isF16())
|
|
return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
|
|
}
|
|
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
|
|
if (elemDestType.isF32())
|
|
return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
|
|
if (elemDestType.isF16())
|
|
return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
|
|
}
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
|
|
/// if one exists. This includes checking to ensure the intrinsic is supported
|
|
/// on the architecture you are compiling for.
|
|
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
|
|
Chipset chipset) {
|
|
auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
|
|
auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
|
|
auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
|
|
Type elemSourceType = sourceVectorType.getElementType();
|
|
Type elemBSourceType = sourceBVectorType.getElementType();
|
|
Type elemDestType = destVectorType.getElementType();
|
|
|
|
const uint32_t k = wmma.getK();
|
|
const bool isRDNA3 = chipset.majorVersion == 11;
|
|
const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
|
|
|
|
// Handle RDNA3 and RDNA4.
|
|
if (isRDNA3 || isRDNA4)
|
|
return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
|
|
k, isRDNA3);
|
|
|
|
// Handle gfx1250.
|
|
if (chipset == kGfx1250)
|
|
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
|
|
elemDestType, k);
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
namespace {
|
|
/// Conversion pattern that rewrites `MFMAOp` into the matching `rocdl` MFMA
/// intrinsic for the configured `chipset`. When only a scaled intrinsic
/// matches, the op is emitted through it with all scale operands set to zero.
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type resultType = typeConverter->convertType(op.getDestD().getType());
    // A bf16 result vector is requested from the intrinsic as an i16 vector
    // and bitcast back to `resultType` at the end.
    Type rawResultType = resultType;
    if (auto resultVecType = dyn_cast<VectorType>(resultType)) {
      if (resultVecType.getElementType().isBF16())
        rawResultType = resultVecType.clone(rewriter.getI16Type());
    }

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");

    // The negate flags are folded into the low bits of the blgp field.
    uint32_t blgpBits = static_cast<uint32_t>(op.getBlgp());
    const bool anyNegation =
        op.getNegateA() || op.getNegateB() || op.getNegateC();
    if (anyNegation) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      blgpBits |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }

    std::optional<StringRef> plainIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>> scaledIntrinsic =
        mfmaOpToScaledIntrinsic(op, chipset);
    if (!plainIntrinsic.has_value() && !scaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    // At least one lookup succeeded; the plain intrinsic wins when both did.
    const bool useScaled = !plainIntrinsic.has_value();
    if (useScaled &&
        (adaptor.getAbid() > 0 || blgpBits > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        useScaled ? std::get<0>(*scaledIntrinsic) : *plainIntrinsic;

    // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
    // allow bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
    const bool allowBf16 =
        chipset >= kGfx950 &&
        (useScaled || intrinsicName.contains("16x16x32.bf16") ||
         intrinsicName.contains("32x32x16.bf16"));

    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(rawResultType);
    Value sourceA = convertMFMAVectorOperand(rewriter, loc,
                                             adaptor.getSourceA(), allowBf16);
    Value sourceB = convertMFMAVectorOperand(rewriter, loc,
                                             adaptor.getSourceB(), allowBf16);
    loweredOp.addOperands({sourceA, sourceB, adaptor.getDestC()});

    if (useScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      Value aTypeCode =
          createI32Constant(rewriter, loc, std::get<1>(*scaledIntrinsic));
      Value bTypeCode =
          createI32Constant(rewriter, loc, std::get<2>(*scaledIntrinsic));
      loweredOp.addOperands({aTypeCode, bTypeCode,
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      Value cbsz = createI32Constant(rewriter, loc, op.getCbsz());
      Value abid = createI32Constant(rewriter, loc, op.getAbid());
      Value blgp = createI32Constant(rewriter, loc, blgpBits);
      loweredOp.addOperands({cbsz, abid, blgp});
    }

    Value result = rewriter.create(loweredOp)->getResult(0);
    if (resultType != rawResultType)
      result = LLVM::BitcastOp::create(rewriter, loc, resultType, result);
    rewriter.replaceOp(op, result);
    return success();
  }
};
|
|
|
|
/// Conversion pattern that rewrites `ScaledMFMAOp` into the matching scaled
/// `rocdl` MFMA intrinsic. The pattern only applies to gfx9 chipsets at or
/// above gfx950.
struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    // The guard requires gfx950, so the diagnostic names gfx950 (it used to
    // say gfx908, which did not match the condition being checked).
    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    // Operands: source A, source B, accumulator.
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    // Then the two type codes, followed by (index, scales) pairs for A and B.
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};
|
|
|
|
/// Conversion pattern that rewrites `WMMAOp` into the matching `rocdl` WMMA
/// intrinsic for the configured `chipset` (major version 11 or 12).
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // Run every remaining check before any operation is created, so that a
    // failing match leaves the IR untouched.
    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    bool isGFX1250 = chipset >= kGfx1250;

    // The WMMA operations represent vectors of bf16s as vectors of i16s
    // (except on gfx1250), so we need to bitcast bfloats to i16 and then
    // bitcast them back.
    auto aType = cast<VectorType>(adaptor.getSourceA().getType());
    auto bType = cast<VectorType>(adaptor.getSourceB().getType());
    auto destCType = cast<VectorType>(adaptor.getDestC().getType());
    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
    VectorType rawOutType = outType;
    if (castOutToI16)
      rawOutType = outType.clone(rewriter.getI16Type());
    Value a = adaptor.getSourceA();
    if (castAToI16)
      a = LLVM::BitcastOp::create(rewriter, loc,
                                  aType.clone(rewriter.getI16Type()), a);
    Value b = adaptor.getSourceB();
    if (castBToI16)
      b = LLVM::BitcastOp::create(rewriter, loc,
                                  bType.clone(rewriter.getI16Type()), b);
    Value destC = adaptor.getDestC();
    if (castDestCToI16)
      destC = LLVM::BitcastOp::create(
          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);

    // Collect the intrinsic's operands and attributes: A, B, then the
    // accumulator.
    SmallVector<Value, 4> operands;
    SmallVector<NamedAttribute, 4> attrs;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
                         op.getSourceA(), operands, attrs, "signA");
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
                         op.getSourceB(), operands, attrs, "signB");
    wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
                          op.getSubwordOffset(), op.getClamp(), operands,
                          attrs);

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands(operands);
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    // Undo the i16 representation of a bf16 result, if one was used.
    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
                                              lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};
|
|
|
|
struct TransposeLoadOpLowering
|
|
: public ConvertOpToLLVMPattern<TransposeLoadOp> {
|
|
TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
|
: ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
|
|
|
|
Chipset chipset;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (chipset != kGfx950)
|
|
return op.emitOpError("Non-gfx950 chipset not supported");
|
|
|
|
Location loc = op.getLoc();
|
|
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
|
|
|
|
// Elements in subbyte memrefs are stored non-contiguously,
|
|
// reject if source is sub-byte memref. Use emulated memrefs instead.
|
|
size_t srcElementSize =
|
|
srcMemRefType.getElementType().getIntOrFloatBitWidth();
|
|
if (srcElementSize < 8)
|
|
return op.emitOpError("Expect source memref to have at least 8 bits "
|
|
"element size, got ")
|
|
<< srcElementSize;
|
|
|
|
auto resultType = cast<VectorType>(op.getResult().getType());
|
|
Value srcPtr =
|
|
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
|
|
(adaptor.getSrcIndices()));
|
|
|
|
size_t numElements = resultType.getNumElements();
|
|
size_t elementTypeSize =
|
|
resultType.getElementType().getIntOrFloatBitWidth();
|
|
|
|
// ROCDL transpose load intrinsics return vectors of 32-bit integers, if
|
|
// the element size is smaller than 16 bits.
|
|
Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
|
|
rewriter.getIntegerType(32));
|
|
Type llvmResultType = typeConverter->convertType(resultType);
|
|
|
|
switch (elementTypeSize) {
|
|
case 4: {
|
|
assert(numElements == 16);
|
|
auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
|
|
rocdlResultType, srcPtr);
|
|
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
|
|
break;
|
|
}
|
|
case 6: {
|
|
assert(numElements == 16);
|
|
auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
|
|
rocdlResultType, srcPtr);
|
|
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
|
|
break;
|
|
}
|
|
case 8: {
|
|
assert(numElements == 8);
|
|
auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
|
|
rocdlResultType, srcPtr);
|
|
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
|
|
break;
|
|
}
|
|
case 16: {
|
|
assert(numElements == 4);
|
|
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
|
|
srcPtr);
|
|
break;
|
|
}
|
|
default:
|
|
return op.emitOpError("Unsupported element size for transpose load");
|
|
}
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Conversion pattern that rewrites `GatherToLDSOp` into
/// `ROCDL::LoadToLDSOp`. Accepted chipsets have major version 9 or 10, and the
/// per-thread transfer must be 1, 2, 4, 12 or 16 bytes wide (the last two only
/// on gfx950).
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: instead of only transfering one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    Type transferType = op.getTransferType();

    // Width of the transfer in bytes, for either a vector or a scalar type.
    int loadWidth;
    if (auto vecTransferType = dyn_cast<VectorType>(transferType))
      loadWidth = (vecTransferType.getNumElements() *
                   vecTransferType.getElementTypeBitWidth()) /
                  8;
    else
      loadWidth = transferType.getIntOrFloatBitWidth() / 8;

    // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
    const bool isNarrowLoad =
        loadWidth == 1 || loadWidth == 2 || loadWidth == 4;
    const bool isWideLoad = loadWidth == 12 || loadWidth == 16;
    if (!isNarrowLoad && !isWideLoad)
      return op.emitOpError("chipset unsupported element size");

    if (isWideLoad && chipset != kGfx950)
      return op.emitOpError("Gather to LDS instructions with 12-byte and "
                            "16-byte load widths are only supported on gfx950");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             adaptor.getSrcIndices());
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             adaptor.getDstIndices());

    IntegerAttr widthAttr = rewriter.getI32IntegerAttr(loadWidth);
    IntegerAttr offsetAttr = rewriter.getI32IntegerAttr(0);
    IntegerAttr auxAttr = rewriter.getI32IntegerAttr(0);
    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, widthAttr, /*offset=*/offsetAttr,
        /*aux=*/auxAttr, ArrayAttr{}, ArrayAttr{}, ArrayAttr{});

    return success();
  }
};
|
|
|
|
namespace {
|
|
struct ExtPackedFp8OpLowering final
|
|
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
|
|
ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
|
: ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
|
|
chipset(chipset) {}
|
|
Chipset chipset;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
struct ScaledExtPackedMatrixOpLowering final
|
|
: public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
|
|
ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
|
|
Chipset chipset)
|
|
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
|
|
chipset(chipset) {}
|
|
Chipset chipset;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ScaledExtPackedMatrixOp op,
|
|
ScaledExtPackedMatrixOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
struct PackedTrunc2xFp8OpLowering final
|
|
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
|
|
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
|
|
Chipset chipset)
|
|
: ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
|
|
chipset(chipset) {}
|
|
Chipset chipset;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
struct PackedStochRoundFp8OpLowering final
|
|
: public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
|
|
PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
|
|
Chipset chipset)
|
|
: ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
|
|
chipset(chipset) {}
|
|
Chipset chipset;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(PackedStochRoundFp8Op op,
|
|
PackedStochRoundFp8OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
struct ScaledExtPackedOpLowering final
|
|
: public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
|
|
ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
|
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
|
|
chipset(chipset) {}
|
|
Chipset chipset;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
struct PackedScaledTruncOpLowering final
|
|
: public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
|
|
PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
|
|
Chipset chipset)
|
|
: ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
|
|
chipset(chipset) {}
|
|
Chipset chipset;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
} // end namespace
|
|
|
|
/// Rewrites `ExtPackedFp8Op` into one of the ROCDL 8-bit-float -> f32
/// conversion ops.
///
/// A scalar source, or a vector source with fewer than four elements, is first
/// placed into the low lanes of an undef four-element i8 vector. The (possibly
/// widened) source is bitcast to i32 and passed, together with the op's index,
/// to the packed `CvtPk*` op when the result is a vector and to the scalar
/// `Cvt*` op otherwise.
///
/// The match fails when the chipset is neither gfx942 nor accepted by
/// `hasOcpFp8`, or when the source element type is accepted by neither
/// `typeIsExpectedBf8ForChipset` nor `typeIsExpectedFp8ForChipset`.
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  // Classify the source element type before emitting any IR so that an
  // unsupported type is reported as a match failure instead of returning
  // success() without having replaced the op.
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  const bool isBf8 = typeIsExpectedBf8ForChipset(chipset, sourceElemType);
  const bool isFp8 =
      !isBf8 && typeIsExpectedFp8ForChipset(chipset, sourceElemType);
  if (!isBf8 && !isFp8)
    return rewriter.notifyMatchFailure(
        loc, "source element type is not an 8-bit float type supported on "
             "target architecture");

  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());

  // Extend to a v4i8.
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
    if (!sourceVecType) {
      // Scalar source: it becomes lane 0.
      longVec = LLVM::InsertElementOp::create(
          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      // Short vector source: copy each lane to the same position.
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
        longVec =
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);

  // Exactly one of isBf8 / isFp8 is true at this point.
  if (resultVecType) {
    if (isBf8)
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        op.getIndex());
    else
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        op.getIndex());
  } else {
    if (isBf8)
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      op.getIndex());
    else
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      op.getIndex());
  }
  return success();
}
|
|
|
|
/// Packs `blockSize`, the source bit width, `scaleWaveHalf` and
/// `firstScaleByte` into the single `scaleSel` value used when lowering
/// amdgpu.scaled_ext_packed_matrix to the rocdl.cvt.scale.pk*.f*.f* operations.
/// (`scaleWaveHalf` is not a high-level attribute; it is derived from
/// firstScaleLane.)
///
/// For 4- and 6-bit sources the layout is:
///   bit 0: set when `blockSize` is 16,
///   bit 1: set when `firstScaleByte` is 2,
///   bit 2: `scaleWaveHalf`.
/// For 8-bit sources the layout is:
///   bit 0: set when `blockSize` is 16,
///   bits 2..1: `firstScaleByte`,
///   bit 3: `scaleWaveHalf`.
/// Argument ranges and the 8-bit encodings listed as invalid are checked with
/// assertions only.
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
                    int32_t firstScaleByte) {
  assert(llvm::is_contained({16, 32}, blockSize));
  assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));

  // Bit 0 has the same meaning in both layouts.
  const int32_t blockBit = (blockSize == 16) ? 1 : 0;

  if (bitWidth != 8) {
    assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
    const int32_t byteBit = (firstScaleByte == 2) ? (1 << 1) : 0;
    assert(llvm::is_contained({0, 1}, scaleWaveHalf));
    const int32_t halfBit = scaleWaveHalf << 2;
    return halfBit | byteBit | blockBit;
  }

  // firstScaleByte is guaranteed to be defined by two bits.
  assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
  const int32_t byteBits = firstScaleByte << 1;
  assert(llvm::is_contained({0, 1}, scaleWaveHalf));
  const int32_t halfBit = scaleWaveHalf << 3;
  const int32_t encoded = halfBit | byteBits | blockBit;
  // These are invalid cases.
  assert(!llvm::is_contained(
      {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, encoded));
  return encoded;
}
|
|
|
|
static std::optional<StringRef>
|
|
scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
|
|
using fp4 = Float4E2M1FNType;
|
|
using fp8 = Float8E4M3FNType;
|
|
using bf8 = Float8E5M2Type;
|
|
using fp6 = Float6E2M3FNType;
|
|
using bf6 = Float6E3M2FNType;
|
|
if (isa<fp4>(srcElemType)) {
|
|
if (destElemType.isF16())
|
|
return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
|
|
if (destElemType.isBF16())
|
|
return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
|
|
if (destElemType.isF32())
|
|
return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
|
|
return std::nullopt;
|
|
}
|
|
if (isa<fp8>(srcElemType)) {
|
|
if (destElemType.isF16())
|
|
return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
|
|
if (destElemType.isBF16())
|
|
return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
|
|
if (destElemType.isF32())
|
|
return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
|
|
return std::nullopt;
|
|
}
|
|
if (isa<bf8>(srcElemType)) {
|
|
if (destElemType.isF16())
|
|
return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
|
|
if (destElemType.isBF16())
|
|
return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
|
|
if (destElemType.isF32())
|
|
return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
|
|
return std::nullopt;
|
|
}
|
|
if (isa<fp6>(srcElemType)) {
|
|
if (destElemType.isF16())
|
|
return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
|
|
if (destElemType.isBF16())
|
|
return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
|
|
if (destElemType.isF32())
|
|
return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
|
|
return std::nullopt;
|
|
}
|
|
if (isa<bf6>(srcElemType)) {
|
|
if (destElemType.isF16())
|
|
return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
|
|
if (destElemType.isBF16())
|
|
return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
|
|
if (destElemType.isF32())
|
|
return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
|
|
return std::nullopt;
|
|
}
|
|
llvm_unreachable("invalid combination of element types for packed conversion "
|
|
"instructions");
|
|
}
|
|
|
|
/// Rewrites `ScaledExtPackedMatrixOp` into the ROCDL op whose name is returned
/// by `scaledExtPacked816ToIntrinsic` for the source/result element types.
///
/// The source is bitcast to a packed integer type (i32 for fp4, 2 x i32 for
/// fp8/bf8, 3 x i32 for fp6/bf6) and the scale is bitcast to i32. The op's
/// block size, source bit width, first scale lane (divided by 16) and first
/// scale byte are folded into the `scaleSel` attribute via `getScaleSel`.
///
/// The match fails on any chipset other than gfx1250 and when type conversion
/// fails; an op error is emitted when no intrinsic matches.
LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
    ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx1250) {
    return rewriter.notifyMatchFailure(
        loc,
        "Scaled fp packed conversion instructions are not available on target "
        "architecture and their emulation is not implemented");
  }

  // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
  // is being selected.
  int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
  int32_t firstScaleByte = op.getFirstScaleByte();
  int32_t blockSize = op.getBlockSize();

  auto srcVecType = cast<VectorType>(op.getSource().getType());
  auto srcElemType = cast<FloatType>(srcVecType.getElementType());
  auto dstVecType = cast<VectorType>(op.getResult().getType());
  auto destElemType = cast<FloatType>(dstVecType.getElementType());
  unsigned bitWidth = srcElemType.getWidth();

  // Pick the integer container matching the total bit size of the source, then
  // run it through the type converter once.
  IntegerType i32 = rewriter.getI32Type();
  Type containerType;
  if (isa<Float4E2M1FNType>(srcElemType))
    containerType = i32;
  else if (isa<Float8E4M3FNType, Float8E5M2Type>(srcElemType))
    containerType = VectorType::get(2, i32);
  else if (isa<Float6E2M3FNType, Float6E3M2FNType>(srcElemType))
    containerType = VectorType::get(3, i32);
  else
    llvm_unreachable("invalid element type for packed scaled ext");

  Type packedType = getTypeConverter()->convertType(containerType);
  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
  if (!packedType || !llvmResultType)
    return rewriter.notifyMatchFailure(op, "type conversion failed");

  std::optional<StringRef> maybeIntrinsic =
      scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
  if (!maybeIntrinsic)
    return op.emitOpError(
        "no intrinsic matching packed scaled conversion on the given chipset");

  int32_t scaleSel =
      getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);

  Value scaleAsI32 =
      LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
  Value sourceAsPacked =
      LLVM::BitcastOp::create(rewriter, loc, packedType, adaptor.getSource());

  // The target op is only known by name, so build it through OperationState.
  OperationState state(loc, *maybeIntrinsic);
  state.addTypes({llvmResultType});
  state.addOperands({sourceAsPacked, scaleAsI32});

  SmallVector<NamedAttribute, 1> attributes;
  attributes.push_back(
      NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
  state.addAttributes(attributes);

  Operation *replacement = rewriter.create(state);
  rewriter.replaceOp(op, replacement);
  return success();
}
|
|
|
|
/// Rewrites `ScaledExtPackedOp` into the matching `ROCDL::CvtScaleF32Pk*` op.
///
/// The source vector is zero-extended, lane by lane, to the full packed width
/// (4 x i8 for the 8-bit float types, 8 x i4 for fp4) when it is shorter, then
/// bitcast to i32 and passed with the scale and the op's index to the ROCDL op
/// selected by the source and destination element types.
///
/// The match fails on any chipset other than gfx950 and when no ROCDL op
/// exists for the element type combination.
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
    ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();

  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();
  VectorType destVecType = cast<VectorType>(op.getResult().getType());
  Type destElemType = destVecType.getElementType();

  VectorType packedVecType;
  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
    VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
  } else if (isa<Float4E2M1FNType>(sourceElemType)) {
    VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
  } else {
    llvm_unreachable("invalid element type for scaled ext");
  }

  // Extend to a packedVectorType. `sourceVecType` comes from `cast<>`, so it
  // is always a vector here and only the lane-by-lane copy is needed.
  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
    Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
    for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
      Value idx = createI32Constant(rewriter, loc, i);
      Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
      longVec =
          LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);

  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else
    return rewriter.notifyMatchFailure(
        loc, "unsupported source/result element type combination for scaled "
             "packed extension");

  return success();
}
|
|
|
|
/// Rewrites `PackedScaledTruncOp` into the matching `ROCDL::CvtScaleF32Pk*`
/// op followed by a bitcast to the converted result type.
///
/// The intermediate integer type is i32 for fp4 results and 2 x i16 otherwise.
/// `existing` is bitcast to that type, or replaced by a zero value when
/// absent. A one-element source is widened to two elements with a zero second
/// lane. f32 sources are passed to the ROCDL op as two scalars; f16 / bf16
/// sources are passed as the vector itself.
///
/// The match fails on any chipset other than gfx950 and when no ROCDL op
/// exists for the element type combination.
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  const auto *converter = getTypeConverter();
  Type v2i16 = converter->convertType(VectorType::get(2, rewriter.getI16Type()));
  Type i32 = converter->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  const bool toBf8 = isa<Float8E5M2Type>(resultElemType);
  const bool toFp8 = isa<Float8E4M3FNType>(resultElemType);
  const bool toFp4 = isa<Float4E2M1FNType>(resultElemType);
  const bool fromF32 = sourceElemType.isF32();
  const bool fromF16 = sourceElemType.isF16();
  const bool fromBF16 = sourceElemType.isBF16();

  Type intResultType = toFp4 ? i32 : v2i16;

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();

  // Previous packed contents to merge into, or zero when none was given.
  Value existing = adaptor.getExisting();
  if (!existing)
    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
  else
    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);

  // Pad a single-element source with a zero second lane.
  if (sourceVecType.getNumElements() < 2) {
    Value lane0 = createI32Constant(rewriter, loc, 0);
    Value first = LLVM::ExtractElementOp::create(rewriter, loc, source, lane0);
    VectorType pairType = VectorType::get(2, sourceElemType);
    source = LLVM::ZeroOp::create(rewriter, loc, pairType);
    source = LLVM::InsertElementOp::create(rewriter, loc, source, first, lane0);
  }

  // The f32 variants take the two lanes as separate scalar operands.
  Value sourceA, sourceB;
  if (fromF32) {
    Value lane0 = createI32Constant(rewriter, loc, 0);
    Value lane1 = createI32Constant(rewriter, loc, 1);
    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, lane0);
    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, lane1);
  }

  // Dispatch on the result element type first, then on the source type.
  Value packed;
  if (toBf8) {
    if (fromF32)
      packed = ROCDL::CvtScaleF32PkBf8F32Op::create(
          rewriter, loc, intResultType, existing, sourceA, sourceB, scale,
          op.getIndex());
    else if (fromF16)
      packed = ROCDL::CvtScaleF32PkBf8F16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
    else if (fromBF16)
      packed = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  } else if (toFp8) {
    if (fromF32)
      packed = ROCDL::CvtScaleF32PkFp8F32Op::create(
          rewriter, loc, intResultType, existing, sourceA, sourceB, scale,
          op.getIndex());
    else if (fromF16)
      packed = ROCDL::CvtScaleF32PkFp8F16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
    else if (fromBF16)
      packed = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  } else if (toFp4) {
    if (fromF32)
      packed = ROCDL::CvtScaleF32PkFp4F32Op::create(
          rewriter, loc, intResultType, existing, sourceA, sourceB, scale,
          op.getIndex());
    else if (fromF16)
      packed = ROCDL::CvtScaleF32PkFp4F16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
    else if (fromBF16)
      packed = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
          rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  }

  // No branch above produced an op: unsupported type combination.
  if (!packed)
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, converter->convertType(resultType), packed);
  return success();
}
|
|
|
|
/// Rewrites `PackedTrunc2xFp8Op` into `ROCDL::CvtPkBf8F32Op` or
/// `ROCDL::CvtPkFp8F32Op`, followed by a bitcast to the converted result type.
///
/// A missing `sourceB` is replaced by an undef value of `sourceA`'s type.
/// `existing` is bitcast to i32, or replaced by an undef i32 when absent.
///
/// The match fails when the chipset is neither gfx942 nor accepted by
/// `hasOcpFp8`, or when the result element type is accepted by neither
/// `typeIsExpectedBf8ForChipset` nor `typeIsExpectedFp8ForChipset`.
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  // Decide which ROCDL op applies before emitting IR. Without this check an
  // unsupported element type would leave the conversion result null and feed
  // a null value into the final bitcast.
  const bool isBf8 = typeIsExpectedBf8ForChipset(chipset, resultElemType);
  const bool isFp8 =
      !isBf8 && typeIsExpectedFp8ForChipset(chipset, resultElemType);
  if (!isBf8 && !isFp8)
    return rewriter.notifyMatchFailure(
        loc, "result element type is not an 8-bit float type supported on "
             "target architecture");

  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (isBf8)
    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());
  else
    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
|
|
|
|
/// Rewrites `PackedStochRoundFp8Op` into `ROCDL::CvtSrBf8F32Op` or
/// `ROCDL::CvtSrFp8F32Op`, followed by a bitcast to the converted result type.
///
/// `existing` is bitcast to i32, or replaced by an undef i32 when absent.
///
/// The match fails when the chipset is neither gfx942 nor accepted by
/// `hasOcpFp8`, or when the result element type is accepted by neither
/// `typeIsExpectedBf8ForChipset` nor `typeIsExpectedFp8ForChipset`.
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  // Decide which ROCDL op applies before emitting IR. Without this check an
  // unsupported element type would leave the conversion result null and feed
  // a null value into the final bitcast.
  const bool isBf8 = typeIsExpectedBf8ForChipset(chipset, resultElemType);
  const bool isFp8 =
      !isBf8 && typeIsExpectedFp8ForChipset(chipset, resultElemType);
  if (!isBf8 && !isFp8)
    return rewriter.notifyMatchFailure(
        loc, "result element type is not an 8-bit float type supported on "
             "target architecture");

  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (isBf8)
    result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());
  else
    result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
|
|
|
|
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
|
|
// operation into the corresponding ROCDL instructions.
|
|
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
|
|
AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
|
: ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
|
|
Chipset chipset;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
// Convert the source operand to the corresponding LLVM type
|
|
Location loc = DppOp.getLoc();
|
|
Value src = adaptor.getSrc();
|
|
Value old = adaptor.getOld();
|
|
Type srcType = src.getType();
|
|
Type oldType = old.getType();
|
|
Type llvmType = nullptr;
|
|
if (srcType.getIntOrFloatBitWidth() < 32) {
|
|
llvmType = rewriter.getI32Type();
|
|
} else if (isa<FloatType>(srcType)) {
|
|
llvmType = (srcType.getIntOrFloatBitWidth() == 32)
|
|
? rewriter.getF32Type()
|
|
: rewriter.getF64Type();
|
|
} else if (isa<IntegerType>(srcType)) {
|
|
llvmType = (srcType.getIntOrFloatBitWidth() == 32)
|
|
? rewriter.getI32Type()
|
|
: rewriter.getI64Type();
|
|
}
|
|
auto llvmSrcIntType = typeConverter->convertType(
|
|
rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
|
|
|
|
// If the source type is less of 32, use bitcast to convert it to i32.
|
|
auto convertOperand = [&](Value operand, Type operandType) {
|
|
if (operandType.getIntOrFloatBitWidth() <= 16) {
|
|
if (llvm::isa<FloatType>(operandType)) {
|
|
operand =
|
|
LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
|
|
}
|
|
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
|
|
32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
|
|
Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
|
|
operand =
|
|
LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
|
|
createI32Constant(rewriter, loc, 0));
|
|
operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
|
|
}
|
|
return operand;
|
|
};
|
|
|
|
src = convertOperand(src, srcType);
|
|
old = convertOperand(old, oldType);
|
|
|
|
// This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
|
|
enum DppCtrl : unsigned {
|
|
ROW_SHL0 = 0x100,
|
|
ROW_SHR0 = 0x110,
|
|
ROW_ROR0 = 0x120,
|
|
WAVE_SHL1 = 0x130,
|
|
WAVE_ROL1 = 0x134,
|
|
WAVE_SHR1 = 0x138,
|
|
WAVE_ROR1 = 0x13C,
|
|
ROW_MIRROR = 0x140,
|
|
ROW_HALF_MIRROR = 0x141,
|
|
BCAST15 = 0x142,
|
|
BCAST31 = 0x143,
|
|
};
|
|
|
|
auto kind = DppOp.getKind();
|
|
auto permArgument = DppOp.getPermArgument();
|
|
uint32_t DppCtrl = 0;
|
|
|
|
switch (kind) {
|
|
|
|
case DPPPerm::quad_perm:
|
|
if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
|
|
int32_t i = 0;
|
|
for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
|
|
uint32_t num = elem.getInt();
|
|
DppCtrl |= num << (i * 2);
|
|
i++;
|
|
}
|
|
}
|
|
break;
|
|
case DPPPerm::row_shl:
|
|
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
|
|
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
|
|
}
|
|
break;
|
|
case DPPPerm::row_shr:
|
|
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
|
|
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
|
|
}
|
|
break;
|
|
case DPPPerm::row_ror:
|
|
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
|
|
DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
|
|
}
|
|
break;
|
|
case DPPPerm::wave_shl:
|
|
DppCtrl = DppCtrl::WAVE_SHL1;
|
|
break;
|
|
case DPPPerm::wave_shr:
|
|
DppCtrl = DppCtrl::WAVE_SHR1;
|
|
break;
|
|
case DPPPerm::wave_rol:
|
|
DppCtrl = DppCtrl::WAVE_ROL1;
|
|
break;
|
|
case DPPPerm::wave_ror:
|
|
DppCtrl = DppCtrl::WAVE_ROR1;
|
|
break;
|
|
case DPPPerm::row_mirror:
|
|
DppCtrl = DppCtrl::ROW_MIRROR;
|
|
break;
|
|
case DPPPerm::row_half_mirror:
|
|
DppCtrl = DppCtrl::ROW_HALF_MIRROR;
|
|
break;
|
|
case DPPPerm::row_bcast_15:
|
|
DppCtrl = DppCtrl::BCAST15;
|
|
break;
|
|
case DPPPerm::row_bcast_31:
|
|
DppCtrl = DppCtrl::BCAST31;
|
|
break;
|
|
}
|
|
|
|
// Check for row_mask, bank_mask, bound_ctrl if they exist and create
|
|
// constants
|
|
auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
|
|
auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
|
|
bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
|
|
|
|
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
|
|
auto dppMovOp =
|
|
ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
|
|
rowMask, bankMask, boundCtrl);
|
|
|
|
Value result = dppMovOp.getRes();
|
|
if (srcType.getIntOrFloatBitWidth() < 32) {
|
|
result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
|
|
if (!llvm::isa<IntegerType>(srcType)) {
|
|
result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
|
|
}
|
|
}
|
|
|
|
// We are replacing the AMDGPU_DPPOp instruction with the new
|
|
// ROCDL_DPPMovOp instruction
|
|
rewriter.replaceOp(DppOp, ValueRange(result));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct AMDGPUSwizzleBitModeLowering
|
|
: public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
|
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
Type i32 = rewriter.getI32Type();
|
|
Value src = adaptor.getSrc();
|
|
SmallVector<Value> decomposed =
|
|
LLVM::decomposeValue(rewriter, loc, src, i32);
|
|
unsigned andMask = op.getAndMask();
|
|
unsigned orMask = op.getOrMask();
|
|
unsigned xorMask = op.getXorMask();
|
|
|
|
// bit 15 is 0 for the BitMode swizzle.
|
|
// https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
|
|
unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
|
|
Value maskValue = createI32Constant(rewriter, loc, mask);
|
|
SmallVector<Value> swizzled;
|
|
for (Value v : decomposed) {
|
|
Value res =
|
|
ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
|
|
swizzled.emplace_back(res);
|
|
}
|
|
|
|
Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Conversion pattern that replaces `PermlaneSwapOp` with
/// `ROCDL::Permlane16SwapOp` or `ROCDL::Permlane32SwapOp`, chosen by the op's
/// row length, applied to each i32 piece of the source value.
///
/// An op error is emitted when the chipset is older than gfx950.
struct AMDGPUPermlaneLowering : ConvertOpToLLVMPattern<PermlaneSwapOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}

  /// Chipset supplied at construction time.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx950)
      return op->emitOpError("permlane_swap is only supported on gfx950+");

    Location loc = op.getLoc();
    Value src = adaptor.getSrc();
    const unsigned rowLength = op.getRowLength();
    const bool fi = op.getFetchInactive();
    const bool boundctrl = op.getBoundCtrl();

    SmallVector<Value> pieces =
        LLVM::decomposeValue(rewriter, loc, src, rewriter.getI32Type());

    SmallVector<Value> permuted;
    for (Value piece : pieces) {
      // The ROCDL swap ops return a two-element struct of the piece's type.
      Type pairType = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {piece.getType(), piece.getType()});

      Value swapped;
      switch (rowLength) {
      case 16:
        swapped = ROCDL::Permlane16SwapOp::create(rewriter, loc, pairType,
                                                  piece, piece, fi, boundctrl);
        break;
      case 32:
        swapped = ROCDL::Permlane32SwapOp::create(rewriter, loc, pairType,
                                                  piece, piece, fi, boundctrl);
        break;
      default:
        llvm_unreachable("unsupported row length");
      }

      Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, swapped, {0});
      Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, swapped, {1});

      // Per `permlane(16|32)` semantics: if the first extracted element equals
      // the input piece, the result is the second element; otherwise it is the
      // first.
      Value firstUnchanged = LLVM::ICmpOp::create(
          rewriter, loc, LLVM::ICmpPredicate::eq, vdst0, piece);
      Value chosen =
          LLVM::SelectOp::create(rewriter, loc, firstUnchanged, vdst1, vdst0);
      permuted.emplace_back(chosen);
    }

    rewriter.replaceOp(
        op, LLVM::composeValue(rewriter, loc, permuted, src.getType()));
    return success();
  }
};
|
|
|
|
/// Returns `accumulator | (value << (shift % 32))`.
///
/// No shift op is emitted when the reduced shift amount is zero, and no `or`
/// is emitted when `accumulator` matches a constant zero; in that case the
/// (possibly shifted) `value` is returned directly. The emitted `or` carries
/// the disjoint flag.
static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
                              Value accumulator, Value value, int64_t shift) {
  const int64_t bitOffset = shift % 32;
  if (bitOffset != 0)
    value = LLVM::ShlOp::create(rewriter, loc, value,
                                createI32Constant(rewriter, loc, bitOffset));

  // Or-ing into a known zero would be a no-op.
  if (matchPattern(accumulator, mlir::m_Zero()))
    return value;

  return LLVM::OrOp::create(rewriter, loc, accumulator, value,
                            /*isDisjoint=*/true);
}
|
|
|
|
template <typename BaseOp>
|
|
struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
|
|
using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
|
|
using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
|
|
|
|
AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
|
: ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
|
|
Chipset chipset;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(BaseOp op, Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (chipset < kGfx1250)
|
|
return op->emitOpError("make_dma_base is only supported on gfx1250");
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
constexpr int32_t constlen = 4;
|
|
Value consts[constlen];
|
|
for (int64_t i = 0; i < constlen; i++)
|
|
consts[i] = createI32Constant(rewriter, loc, i);
|
|
|
|
constexpr int32_t sgprslen = constlen;
|
|
Value sgprs[sgprslen];
|
|
for (int64_t i = 0; i < sgprslen; i++) {
|
|
sgprs[i] = consts[0];
|
|
}
|
|
|
|
sgprs[0] = consts[1];
|
|
|
|
if (op.isGather()) {
|
|
sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
|
|
|
|
auto type = cast<TDMGatherBaseType>(op.getResult().getType());
|
|
Type indexType = type.getIndexType();
|
|
unsigned indexSize = indexType.getIntOrFloatBitWidth();
|
|
assert(llvm::is_contained({16u, 32u}, indexSize) &&
|
|
"expected index_size to be 16 or 32");
|
|
unsigned idx = (indexSize / 16) - 1;
|
|
|
|
if (idx)
|
|
sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
|
|
}
|
|
|
|
ValueRange ldsIndices = adaptor.getLdsIndices();
|
|
Value lds = adaptor.getLds();
|
|
auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
|
|
|
|
Value ldsPtr = ConvertToLLVMPattern::getStridedElementPtr(
|
|
rewriter, loc, ldsMemRefType, lds, ldsIndices);
|
|
|
|
ValueRange globalIndices = adaptor.getGlobalIndices();
|
|
Value global = adaptor.getGlobal();
|
|
auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
|
|
|
|
Value globalPtr = ConvertToLLVMPattern::getStridedElementPtr(
|
|
rewriter, loc, globalMemRefType, global, globalIndices);
|
|
|
|
Type i32 = rewriter.getI32Type();
|
|
Type i64 = rewriter.getI64Type();
|
|
|
|
sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
|
|
Value castForGlobalAddr =
|
|
LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
|
|
|
|
sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
|
|
|
|
Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
|
|
createI64Constant(rewriter, loc, 32));
|
|
|
|
Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
|
|
|
|
Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
|
|
highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
|
|
|
|
sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
|
|
|
|
Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
|
|
assert(v4i32 && "expected type conversion to succeed");
|
|
Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
|
|
|
|
for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
|
|
result =
|
|
LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
|
|
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct AMDGPUMakeDmaDescriptorLowering
|
|
: public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> {
|
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
|
|
|
AMDGPUMakeDmaDescriptorLowering(const LLVMTypeConverter &converter,
|
|
Chipset chipset)
|
|
: ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter),
|
|
chipset(chipset) {}
|
|
Chipset chipset;
|
|
|
|
Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
|
|
|
|
Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr0) const {
|
|
Value mask = op.getWorkgroupMask();
|
|
if (!mask)
|
|
return sgpr0;
|
|
|
|
Type i16 = rewriter.getI16Type();
|
|
mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
|
|
Type i32 = rewriter.getI32Type();
|
|
Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
|
|
return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
|
|
}
|
|
|
|
Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr0, ArrayRef<Value> consts) const {
|
|
unsigned elementTypeWidthInBits = op.getElementTypeWidth();
|
|
assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
|
|
"expected type width to be 8, 16, 32, or 64.");
|
|
int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
|
|
Value size = consts[idx];
|
|
return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
|
|
}
|
|
|
|
Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr0, ArrayRef<Value> consts) const {
|
|
if (!adaptor.getAtomicBarrierAddress())
|
|
return sgpr0;
|
|
|
|
return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
|
|
}
|
|
|
|
Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr0, ArrayRef<Value> consts) const {
|
|
if (!adaptor.getGlobalIncrement())
|
|
return sgpr0;
|
|
|
|
return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
|
|
}
|
|
|
|
Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr0, ArrayRef<Value> consts) const {
|
|
if (!op.getPadAmount())
|
|
return sgpr0;
|
|
|
|
return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
|
|
}
|
|
|
|
Value setEarlyTimeout(MakeDmaDescriptorOp op, OpAdaptor adaptorm,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr0, ArrayRef<Value> consts) const {
|
|
if (!op.getWorkgroupMask())
|
|
return sgpr0;
|
|
|
|
return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
|
|
}
|
|
|
|
Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr0, ArrayRef<Value> consts) const {
|
|
if (!op.getPadAmount())
|
|
return sgpr0;
|
|
|
|
// pre-condition: padInterval can be a power of two between 2 and 256.
|
|
// TODO: Validation if the value breaks the pre-condition.
|
|
// If the pre-condition fails, there is a possibility of
|
|
// affecting the higher bits. In a following PR implement
|
|
// RuntimeVerifiableOpInterface that instruments conditions that need to be
|
|
// checked at runtime.
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
Value padInterval = adaptor.getPadInterval();
|
|
padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
|
|
padInterval, false);
|
|
padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
|
|
// post-condition: padInterval can be a value between 0 and 7.
|
|
return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
|
|
}
|
|
|
|
Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr0, ArrayRef<Value> consts) const {
|
|
if (!op.getPadAmount())
|
|
return sgpr0;
|
|
|
|
// pre-condition: padAmount is a value between 1-128.
|
|
// TODO: Validation if the value breaks the pre-condition.
|
|
// If the pre-condition fails, there is a possibility of
|
|
// affecting the higher bits. In a following PR implement
|
|
// RuntimeVerifiableOpInterface that instruments conditions that need to be
|
|
// checked at runtime.
|
|
Value padAmount = adaptor.getPadAmount();
|
|
padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
|
|
// post-condition: padAmount is a value between 0-127.
|
|
return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
|
|
}
|
|
|
|
Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter,
|
|
Location loc, Value sgpr1,
|
|
ArrayRef<Value> consts) const {
|
|
if (!adaptor.getAtomicBarrierAddress())
|
|
return sgpr1;
|
|
|
|
Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
|
|
auto barrierAddressTy =
|
|
cast<MemRefType>(op.getAtomicBarrierAddress().getType());
|
|
ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
|
|
atomicBarrierAddress =
|
|
getStridedElementPtr(rewriter, loc, barrierAddressTy,
|
|
atomicBarrierAddress, atomicBarrierIndices);
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
// pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
|
|
// that the 3 LSBs are zero.
|
|
// TODO: Validation if the value breaks the pre-condition.
|
|
// In a following PR implement RuntimeVerifiableOpInterface
|
|
// that instruments conditions that need to be checked at runtime.
|
|
atomicBarrierAddress =
|
|
LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
|
|
atomicBarrierAddress =
|
|
LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
|
|
Value mask = createI32Constant(rewriter, loc, 0xFFFF);
|
|
atomicBarrierAddress =
|
|
LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
|
|
return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
|
|
}
|
|
|
|
std::pair<Value, Value> setTensorDimX(MakeDmaDescriptorOp op,
|
|
OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter,
|
|
Location loc, Value sgpr1, Value sgpr2,
|
|
ArrayRef<Value> consts, uint64_t dimX,
|
|
uint32_t offset) const {
|
|
ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
|
|
ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
|
|
SmallVector<OpFoldResult> mixedGlobalSizes =
|
|
getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
|
|
if (mixedGlobalSizes.size() <= dimX)
|
|
return {sgpr1, sgpr2};
|
|
|
|
OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
|
|
// pre-condition: tensorDimX is less than 2^32-1
|
|
// TODO: Validation if the value breaks the pre-condition.
|
|
// In a following PR implement RuntimeVerifiableOpInterface that instruments
|
|
// conditions that need to be checked at runtime. This could also be fixed
|
|
// by saying that mixedGlobalSizes is a DynamicI32List.
|
|
Value tensorDimX;
|
|
if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult))
|
|
tensorDimX =
|
|
createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
|
|
else {
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
tensorDimX = cast<Value>(tensorDimXOpFoldResult);
|
|
tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
|
|
}
|
|
|
|
sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
|
|
|
|
Value c16 = createI32Constant(rewriter, loc, 16);
|
|
Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
|
|
sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
|
|
return {sgpr1, sgpr2};
|
|
}
|
|
|
|
std::pair<Value, Value> setTensorDim0(MakeDmaDescriptorOp op,
|
|
OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter,
|
|
Location loc, Value sgpr1, Value sgpr2,
|
|
ArrayRef<Value> consts) const {
|
|
return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
|
|
48);
|
|
}
|
|
|
|
std::pair<Value, Value> setTensorDim1(MakeDmaDescriptorOp op,
|
|
OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter,
|
|
Location loc, Value sgpr2, Value sgpr3,
|
|
ArrayRef<Value> consts) const {
|
|
return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
|
|
80);
|
|
}
|
|
|
|
Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr, ArrayRef<Value> consts, size_t dimX,
|
|
int64_t offset) const {
|
|
ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
|
|
ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
|
|
SmallVector<OpFoldResult> mixedSharedSizes =
|
|
getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
|
|
if (mixedSharedSizes.size() <= dimX)
|
|
return sgpr;
|
|
|
|
OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
|
|
// pre-condition: tileDimX is less than 2^16-1
|
|
// TODO: Validation if the value breaks the pre-condition.
|
|
// If the pre-condition fails, there is a possibility of
|
|
// affecting the higher bits. In a following PR implement
|
|
// RuntimeVerifiableOpInterface that instruments conditions that need to be
|
|
// checked at runtime. This could also be fixed by saying that
|
|
// mixedSharedSizes is a DynamicI16List.
|
|
Value tileDimX;
|
|
if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult))
|
|
tileDimX =
|
|
createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
|
|
else {
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
tileDimX = cast<Value>(tileDimXOpFoldResult);
|
|
tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
|
|
}
|
|
|
|
return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
|
|
}
|
|
|
|
Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr3, ArrayRef<Value> consts) const {
|
|
return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
|
|
}
|
|
|
|
Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr4, ArrayRef<Value> consts) const {
|
|
return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
|
|
}
|
|
|
|
Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr4, ArrayRef<Value> consts) const {
|
|
return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
|
|
}
|
|
|
|
std::pair<Value, Value>
|
|
setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgprY, Value sgprZ, ArrayRef<Value> consts,
|
|
size_t dimX, int64_t offset) const {
|
|
ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
|
|
ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
|
|
SmallVector<OpFoldResult> mixedGlobalStrides =
|
|
getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
|
|
|
|
if (mixedGlobalStrides.size() <= dimX)
|
|
return {sgprY, sgprZ};
|
|
|
|
OpFoldResult tensorDimXStrideOpFoldResult =
|
|
*(mixedGlobalStrides.rbegin() + dimX);
|
|
// pre-condition: tensorDimXStride is less than 2^48-1
|
|
// TODO: Validation if the value breaks the pre-condition.
|
|
// In a following PR implement RuntimeVerifiableOpInterface that instruments
|
|
// conditions that need to be checked at runtime.
|
|
Value tensorDimXStride;
|
|
if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
|
|
tensorDimXStride =
|
|
createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
|
|
else
|
|
tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
|
|
|
|
constexpr int64_t first48bits = (1ll << 48) - 1;
|
|
Value mask = createI64Constant(rewriter, loc, first48bits);
|
|
tensorDimXStride =
|
|
LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
Value tensorDimXStrideLow =
|
|
LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
|
|
sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
|
|
|
|
int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
|
|
Value shiftVal = createI64Constant(rewriter, loc, shift);
|
|
Value tensorDimXStrideHigh =
|
|
LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
|
|
tensorDimXStrideHigh =
|
|
LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
|
|
sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
|
|
offset + shift);
|
|
return {sgprY, sgprZ};
|
|
}
|
|
|
|
std::pair<Value, Value>
|
|
setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
|
|
return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
|
|
0, 160);
|
|
}
|
|
|
|
std::pair<Value, Value>
|
|
setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
|
|
return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
|
|
1, 208);
|
|
}
|
|
|
|
Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
ArrayRef<Value> consts) const {
|
|
Value sgprs[8];
|
|
for (int64_t i = 0; i < 8; i++) {
|
|
sgprs[i] = consts[0];
|
|
}
|
|
|
|
sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
|
|
sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
|
|
sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
|
|
sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
|
|
sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
|
|
sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
|
|
sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
|
|
sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
|
|
|
|
sgprs[1] =
|
|
setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
|
|
std::tie(sgprs[1], sgprs[2]) =
|
|
setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
|
|
std::tie(sgprs[2], sgprs[3]) =
|
|
setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
|
|
|
|
sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
|
|
sgprs[4] = setTileDim1(op, adaptor, rewriter, loc, sgprs[4], consts);
|
|
sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
|
|
std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
|
|
op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
|
|
std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
|
|
op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
|
|
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
|
|
assert(v8i32 && "expected type conversion to succeed");
|
|
Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
|
|
|
|
for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
|
|
dgroup1 =
|
|
LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
|
|
}
|
|
|
|
return dgroup1;
|
|
}
|
|
|
|
Value setTensorDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
|
|
int64_t offset) const {
|
|
ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
|
|
ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
|
|
SmallVector<OpFoldResult> mixedGlobalSizes =
|
|
getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
|
|
if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
|
|
return sgpr0;
|
|
|
|
OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
|
|
Value tensorDimX;
|
|
if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult))
|
|
tensorDimX =
|
|
createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
|
|
else {
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
tensorDimX = cast<Value>(tensorDimXOpFoldResult);
|
|
tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
|
|
}
|
|
|
|
return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
|
|
}
|
|
|
|
Value setTensorDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr0, ArrayRef<Value> consts) const {
|
|
return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
|
|
}
|
|
|
|
Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
|
|
Location loc, Value accumulator,
|
|
Value value, int64_t shift) const {
|
|
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
value = LLVM::TruncOp::create(rewriter, loc, i32, value);
|
|
return setValueAtOffset(rewriter, loc, accumulator, value, shift);
|
|
}
|
|
|
|
Value setLDSAddrIncrement(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr1, ArrayRef<Value> consts,
|
|
int64_t offset) const {
|
|
Value ldsAddrIncrement = adaptor.getLdsIncrement();
|
|
return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
|
|
}
|
|
|
|
std::pair<Value, Value>
|
|
setGlobalAddrIncrement(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
|
|
int64_t offset) const {
|
|
Value globalAddrIncrement = adaptor.getGlobalIncrement();
|
|
sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
|
|
globalAddrIncrement, offset);
|
|
Value shift = createI64Constant(rewriter, loc, 32);
|
|
globalAddrIncrement =
|
|
LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
|
|
constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
|
|
sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
|
|
globalAddrIncrement, offset + 32);
|
|
Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
|
|
sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
|
|
return {sgpr2, sgpr3};
|
|
}
|
|
|
|
Value setTensorDim3OrLDSAddrIncrement(MakeDmaDescriptorOp op,
|
|
OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter,
|
|
Location loc, Value sgpr1,
|
|
ArrayRef<Value> consts) const {
|
|
Value ldsIncrement = op.getLdsIncrement();
|
|
constexpr int64_t dim = 3;
|
|
constexpr int64_t offset = 32;
|
|
if (!ldsIncrement)
|
|
return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
|
|
offset);
|
|
return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
|
|
offset);
|
|
}
|
|
|
|
std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
|
|
MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc, Value sgpr2,
|
|
Value sgpr3, ArrayRef<Value> consts) const {
|
|
Value globalIncrement = op.getGlobalIncrement();
|
|
constexpr int32_t dim = 2;
|
|
constexpr int32_t offset = 64;
|
|
if (!globalIncrement)
|
|
return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
|
|
consts, dim, offset);
|
|
return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
|
|
consts, offset);
|
|
}
|
|
|
|
Value setIterateCount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr3, ArrayRef<Value> consts,
|
|
int32_t offset) const {
|
|
Value iterationCount = adaptor.getIterationCount();
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
// pre-condition: iterationCount is in the inclusive interval [1, 256].
|
|
// TODO: validation if the value breaks the pre-condition.
|
|
// If the pre-condition fails, there is a possibility of
|
|
// affecting the higher bits. In a following PR implement
|
|
// RuntimeVerifiableOpInterface that instruments conditions that need to be
|
|
// checked at runtime.
|
|
iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
|
|
iterationCount =
|
|
LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
|
|
return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
|
|
}
|
|
|
|
Value setTileDim3OrIterateCount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter,
|
|
Location loc, Value sgpr3,
|
|
ArrayRef<Value> consts) const {
|
|
Value iterateCount = op.getIterationCount();
|
|
constexpr int32_t dim = 2;
|
|
constexpr int32_t offset = 112;
|
|
if (!iterateCount)
|
|
return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
|
|
offset);
|
|
|
|
return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
|
|
}
|
|
|
|
Value getDGroup2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
ArrayRef<Value> consts) const {
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
|
|
assert(v4i32 && "expected type conversion to succeed.");
|
|
|
|
bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
|
|
if (onlyNeedsTwoDescriptors)
|
|
return LLVM::ZeroOp::create(rewriter, loc, v4i32);
|
|
|
|
constexpr int64_t sgprlen = 4;
|
|
Value sgprs[sgprlen];
|
|
for (int i = 0; i < sgprlen; i++)
|
|
sgprs[i] = consts[0];
|
|
|
|
sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
|
|
sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
|
|
sgprs[1], consts);
|
|
std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
|
|
op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
|
|
sgprs[3] =
|
|
setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
|
|
|
|
Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
|
|
for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
|
|
dgroup2 =
|
|
LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
|
|
|
|
return dgroup2;
|
|
}
|
|
|
|
std::pair<Value, Value>
|
|
setTensorDim3Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
|
|
constexpr int32_t dim = 3;
|
|
constexpr int32_t offset = 0;
|
|
return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
|
|
dim, offset);
|
|
}
|
|
|
|
std::pair<Value, Value> setTensorDim4(MakeDmaDescriptorOp op,
|
|
OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter,
|
|
Location loc, Value sgpr1, Value sgpr2,
|
|
ArrayRef<Value> consts) const {
|
|
constexpr int32_t dim = 4;
|
|
constexpr int32_t offset = 48;
|
|
return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
|
|
offset);
|
|
}
|
|
|
|
Value setTileDim4(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
Value sgpr2, ArrayRef<Value> consts) const {
|
|
constexpr int32_t dim = 4;
|
|
constexpr int32_t offset = 80;
|
|
return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
|
|
}
|
|
|
|
Value getDGroup3(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter, Location loc,
|
|
ArrayRef<Value> consts) const {
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
|
|
assert(v4i32 && "expected type conversion to succeed.");
|
|
bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
|
|
if (onlyNeedsTwoDescriptors)
|
|
return LLVM::ZeroOp::create(rewriter, loc, v4i32);
|
|
|
|
constexpr int32_t sgprlen = 4;
|
|
Value sgprs[sgprlen];
|
|
for (int i = 0; i < sgprlen; i++)
|
|
sgprs[i] = consts[0];
|
|
|
|
std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
|
|
op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
|
|
std::tie(sgprs[1], sgprs[2]) =
|
|
setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
|
|
sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
|
|
|
|
Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
|
|
for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
|
|
dgroup3 =
|
|
LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
|
|
|
|
return dgroup3;
|
|
}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(MakeDmaDescriptorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (chipset < kGfx1250)
|
|
return op->emitOpError(
|
|
"make_dma_descriptor is only supported on gfx1250");
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
IntegerType i32 = rewriter.getI32Type();
|
|
[[maybe_unused]] Type v4i32 =
|
|
this->typeConverter->convertType(VectorType::get(4, i32));
|
|
assert(v4i32 && "expected type conversion to succeed");
|
|
|
|
SmallVector<Value> consts;
|
|
for (int64_t i = 0; i < 8; i++)
|
|
consts.push_back(createI32Constant(rewriter, loc, i));
|
|
|
|
Value dgroup0 = this->getDGroup0(adaptor);
|
|
Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
|
|
Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
|
|
Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
|
|
SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
|
|
rewriter.replaceOpWithMultiple(op, {results});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct ConvertAMDGPUToROCDLPass
|
|
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
|
|
using Base::Base;
|
|
|
|
void runOnOperation() override {
|
|
MLIRContext *ctx = &getContext();
|
|
FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
|
|
if (failed(maybeChipset)) {
|
|
emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
|
|
return signalPassFailure();
|
|
}
|
|
|
|
RewritePatternSet patterns(ctx);
|
|
LLVMTypeConverter converter(ctx);
|
|
|
|
populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
|
|
amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
|
|
LLVMConversionTarget target(getContext());
|
|
target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
|
|
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
|
|
target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
|
std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Registers, on `typeConverter`, the mapping from `gpu::AddressSpace` values
/// to the corresponding ROCDL address-space constants.
void mlir::amdgpu::populateCommonGPUTypeAndAttributeConversions(
    TypeConverter &typeConverter) {
  // Global / Workgroup / Private map to the ROCDL global / shared / private
  // address spaces. The switch has no default so that a newly added enum
  // value is flagged by the compiler.
  auto toRocdlAddressSpace = [](gpu::AddressSpace gpuSpace) {
    switch (gpuSpace) {
    case gpu::AddressSpace::Global:
      return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
    case gpu::AddressSpace::Workgroup:
      return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
    case gpu::AddressSpace::Private:
      return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
    }
    llvm_unreachable("unknown address space enum value");
  };

  populateGpuMemorySpaceAttributeConversions(typeConverter,
                                             toRocdlAddressSpace);
}
|
|
|
|
/// Registers the AMDGPU-specific conversions on `typeConverter`:
///   - `amdgpu::AddressSpaceAttr` on memref types becomes an i64 integer
///     attribute (FatRawBuffer -> 7, BufferRsrc -> 8,
///     FatStructuredBuffer -> 9); any other value aborts the conversion;
///   - `TDMBaseType` becomes whatever `vector<4xi32>` converts to.
void mlir::populateAMDGPUTypeAndAttributeConversions(
    TypeConverter &typeConverter) {
  typeConverter.addTypeAttributeConversion(
      [](BaseMemRefType memrefType, amdgpu::AddressSpaceAttr addressSpaceAttr)
          -> TypeConverter::AttributeConversionResult {
        Type i64 = IntegerType::get(addressSpaceAttr.getContext(), 64);

        // Pick the numeric address space first, then build the attribute
        // once. The switch has no default so the compiler flags new enum
        // values.
        std::optional<int64_t> addressSpaceNumber;
        switch (addressSpaceAttr.getValue()) {
        case amdgpu::AddressSpace::FatRawBuffer:
          addressSpaceNumber = 7;
          break;
        case amdgpu::AddressSpace::BufferRsrc:
          addressSpaceNumber = 8;
          break;
        case amdgpu::AddressSpace::FatStructuredBuffer:
          addressSpaceNumber = 9;
          break;
        }

        if (!addressSpaceNumber)
          return TypeConverter::AttributeConversionResult::abort();
        return IntegerAttr::get(i64, *addressSpaceNumber);
      });

  // The lambda holds a reference to `typeConverter`, which therefore has to
  // outlive every use of this conversion.
  typeConverter.addConversion([&typeConverter](TDMBaseType tdmBase) -> Type {
    Type i32 = IntegerType::get(tdmBase.getContext(), 32);
    Type v4i32 = VectorType::get(4, i32);
    return typeConverter.convertType(v4i32);
  });
}
|
|
|
|
/// Registers the AMDGPU type/attribute conversions on `converter` and adds
/// every AMDGPU-to-ROCDL lowering pattern of this file to `patterns`. All
/// patterns except the swizzle one are constructed with `chipset`.
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns,
                                                   Chipset chipset) {
  populateAMDGPUTypeAndAttributeConversions(converter);

  // Buffer casts, loads, stores and atomics.
  patterns.add<
      FatRawBufferCastLowering,
      RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
      RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
      RawBufferOpLowering<RawBufferAtomicFaddOp,
                          ROCDL::RawPtrBufferAtomicFaddOp>,
      RawBufferOpLowering<RawBufferAtomicFmaxOp,
                          ROCDL::RawPtrBufferAtomicFmaxOp>,
      RawBufferOpLowering<RawBufferAtomicSmaxOp,
                          ROCDL::RawPtrBufferAtomicSmaxOp>,
      RawBufferOpLowering<RawBufferAtomicUminOp,
                          ROCDL::RawPtrBufferAtomicUminOp>,
      RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                          ROCDL::RawPtrBufferAtomicCmpSwap>>(converter,
                                                             chipset);

  // DPP, waits, barriers and matrix ops.
  patterns.add<AMDGPUDPPLowering, MemoryCounterWaitOpLowering,
               LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering,
               ScaledMFMAOpLowering, WMMAOpLowering>(converter, chipset);

  // Packed extension / truncation ops.
  patterns.add<ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
               ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
               PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(
      converter, chipset);

  // LDS gathers, transposed loads, permlane and the DMA base / descriptor
  // ops.
  patterns.add<GatherToLDSOpLowering, TransposeLoadOpLowering,
               AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
               AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
               AMDGPUMakeDmaDescriptorLowering>(converter, chipset);

  // The swizzle lowering is constructed from the converter alone.
  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
|