1950 lines
82 KiB
C++

//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "../LLVMCommon/MemRefDescriptor.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>
namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::amdgpu;
// Named chipset versions that the lowering patterns below compare against.
constexpr Chipset kGfx908{9, 0, 8};
constexpr Chipset kGfx90a{9, 0, 0xa};
constexpr Chipset kGfx942{9, 4, 2};
constexpr Chipset kGfx950{9, 5, 0};
/// Convert the integer `val`, interpreted as unsigned, to i32.
///
/// The value is returned unchanged if it is already i32. Wider values are
/// truncated and narrower ones are zero-extended. `val` must have an integer
/// type; the `cast` asserts this.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  auto srcType = cast<IntegerType>(val.getType());
  if (srcType == i32)
    return val;
  if (srcType.getWidth() > 32)
    return LLVM::TruncOp::create(rewriter, loc, i32, val);
  return LLVM::ZExtOp::create(rewriter, loc, i32, val);
}
/// Materialize `value` as an `llvm.mlir.constant` of type i32.
static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32Ty = rewriter.getI32Type();
  Value constant = LLVM::ConstantOp::create(rewriter, loc, i32Ty, value);
  return constant;
}
/// Materialize the boolean `value` as an `llvm.mlir.constant` of type i1.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type i1Ty = rewriter.getI1Type();
  Value constant = LLVM::ConstantOp::create(rewriter, loc, i1Ty, value);
  return constant;
}
/// Returns the linear (element, not byte) index used to access an element in
/// the memref, as an i32.
///
/// Each index is multiplied by its dimension's stride, and the products are
/// summed. Unit strides skip the multiply. Static strides become constants.
/// Dynamic strides are read from `memRefDescriptor` and converted to i32.
/// When `indices` is empty, a constant 0 is returned.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value linearIndex;
  for (auto [dim, idx, stride] : llvm::enumerate(indices, strides)) {
    Value term = idx;
    if (stride != 1) {
      Value strideValue;
      if (ShapedType::isDynamic(stride))
        strideValue = convertUnsignedToI32(
            rewriter, loc, memRefDescriptor.stride(rewriter, loc, dim));
      else
        strideValue = LLVM::ConstantOp::create(rewriter, loc, i32, stride);
      term = LLVM::MulOp::create(rewriter, loc, term, strideValue);
    }
    if (linearIndex)
      linearIndex = LLVM::AddOp::create(rewriter, loc, linearIndex, term);
    else
      linearIndex = term;
  }
  if (!linearIndex)
    return createI32Constant(rewriter, loc, 0);
  return linearIndex;
}
/// Compute the `num_records` field of a buffer resource for a memref, as an
/// i32. This is the number of bytes up to one element past the greatest
/// possible valid index, i.e. `max_i(size[i] * stride[i]) * elementByteWidth`.
///
/// If the shape and every stride are static, the result is folded to a
/// constant; a rank-0 memref counts as one element. The fold asserts that the
/// result fits in 32 bits. Otherwise the maximum is computed at runtime with
/// unsigned max, converted to i32 and then scaled by `elementByteWidth`.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           uint32_t elementByteWidth) {
  const uint32_t rank = memrefType.getRank();
  const bool isFullyStatic = memrefType.hasStaticShape() &&
                             !llvm::any_of(strides, ShapedType::isDynamic);
  if (isFullyStatic) {
    ArrayRef<int64_t> shape = memrefType.getShape();
    int64_t extent = (rank == 0) ? 1 : 0;
    for (uint32_t dim = 0; dim < rank; ++dim)
      extent = std::max(extent, shape[dim] * strides[dim]);
    const int64_t extentBytes = extent * elementByteWidth;
    assert(extentBytes < std::numeric_limits<uint32_t>::max() &&
           "the memref buffer is too large");
    return createI32Constant(rewriter, loc, static_cast<int32_t>(extentBytes));
  }

  Value maxIndex;
  for (uint32_t dim = 0; dim < rank; ++dim) {
    Value dimSize = memrefDescriptor.size(rewriter, loc, dim);
    Value dimStride = memrefDescriptor.stride(rewriter, loc, dim);
    Value dimExtent = LLVM::MulOp::create(rewriter, loc, dimSize, dimStride);
    if (maxIndex)
      maxIndex = LLVM::UMaxOp::create(rewriter, loc, maxIndex, dimExtent);
    else
      maxIndex = dimExtent;
  }
  Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
  Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
  return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst);
}
/// Build a buffer resource pointer (in `addressSpace`) for `basePointer` via
/// `rocdl.make.buffer.rsrc`, using `numRecords` as its size field.
///
/// The stride field is normally 0. On gfx942 and later gfx9 chips, when
/// `cacheSwizzleStride` is given, it is zero-extended to i16 and bit 14 is set
/// to enable cache swizzling. On gfx10+ the flag word additionally sets the
/// reserved bit 24 and selects the out-of-bounds mode from `boundsCheck`.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  Type i16 = rewriter.getI16Type();
  const bool enableCacheSwizzle =
      chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride;

  Value stride;
  if (!enableCacheSwizzle) {
    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));
  } else {
    Value strideBits =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleEnableBit = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, strideBits, swizzleEnableBit,
                                /*isDisjoint=*/true);
  }

  // Flag word layout:
  //   bits 0-11:  dst sel, ignored by these intrinsics
  //   bits 12-14: numeric format (ignored, must be nonzero, 7 = float)
  //   bits 15-18: data format (ignored, must be nonzero, 4 = 32-bit)
  //   bit 19:     in nested heap (0 here)
  //   bit 20:     behavior on unmap (0 means "return 0 / ignore")
  //   bits 21-22: index stride for swizzles (N/A)
  //   bit 23:     add thread ID (0)
  //   bit 24:     reserved to 1 (RDNA) or 0 (CDNA)
  //   bits 25-26: reserved (0)
  //   bit 27:     buffer is non-volatile (CDNA only)
  //   bits 28-29: out-of-bounds select, RDNA only (0 = structured,
  //               1 = check index, 2 = none, 3 = either swizzles or testing
  //               against offset field)
  //   bits 30-31: type (must be 0)
  uint32_t flags = (7u << 12) | (4u << 15);
  if (chipset.majorVersion >= 10) {
    const uint32_t oobSelect = boundsCheck ? 3u : 2u;
    flags |= (1u << 24) | (oobSelect << 28);
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);

  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  return rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
}
namespace {
/// Lowers `amdgpu.fat_raw_buffer_cast`. The source memref's base pointer is
/// wrapped in a buffer resource in address space 7, and a new memref
/// descriptor is built around it.
///
/// The sizes and strides arrays are forwarded unchanged. Without
/// `resetOffset`, the aligned pointer and the original offset are used. With
/// `resetOffset`, the base comes from `bufferPtr` and the offset is 0.
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto memrefType = cast<MemRefType>(op.getSource().getType());
    MemRefDescriptor descriptor(adaptor.getSource());

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    SmallVector<int64_t, 5> strideVals;
    int64_t unusedOffset = 0;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    // Prefer an explicit byte count; otherwise derive one from the layout.
    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer;
    Value offset;
    if (adaptor.getResetOffset()) {
      basePointer = descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                         memrefType);
      offset = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                        rewriter.getIndexAttr(0));
    } else {
      basePointer = descriptor.alignedPtr(rewriter, loc);
      offset = descriptor.offset(rewriter, loc);
    }

    // Forward the whole sizes/strides arrays instead of unpacking and
    // re-packing each entry. Rank-0 descriptors have neither.
    const bool hasSizes = memrefType.getRank() > 0;
    Value sizes;
    Value strides;
    if (hasSizes) {
      sizes = LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                           kSizePosInMemRefDescriptor);
      strides = LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                             kStridePosInMemRefDescriptor);
    }

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    // The fat pointer serves as both the allocated and the aligned pointer.
    SmallVector<int64_t> allocatedPos{kAllocatedPtrPosInMemRefDescriptor};
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
                                         allocatedPos);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
                                         kAlignedPtrPosInMemRefDescriptor);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
                                         kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
                                           kSizePosInMemRefDescriptor);
      result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
                                           kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Define lowering patterns for raw buffer ops.
///
/// `GpuOp` is one of the `amdgpu.raw_buffer_*` ops and `Intrinsic` is the
/// ROCDL op it lowers to. The memref operand becomes a buffer resource (see
/// makeBufferRsrc) whose `num_records` is derived from the memref's layout.
/// The indices, plus any positive `indexOffset`, are linearized and scaled to
/// a byte `voffset`. `sgprOffset` is likewise scaled from elements to bytes.
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  // Largest total data width, in bits, accepted for one op's data operand.
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    // ODS operand group 0 is the data operand of writing ops. When it is the
    // memref itself, the op has no data operand (a load).
    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    // For writing ops, ODS operand group 1 is the compare value if the op has
    // one (compare-and-swap); otherwise it is the memref.
    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices, trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // Choose the type the intrinsic actually operates on (llvmBufferValType):
    // - If we want to load a vector<NxT> with total size <= 32 bits, use a
    //   scalar i(N * bitsize(T)) load and bitcast it.
    // - If bitsize(T) < 32 and the total size is > 32 bits, use a vector load
    //   of (N * bitsize(T) / 32) x i32 and bitcast.
    // - 2-element vectors on raw_buffer_atomic_fadd keep their vector type
    //   (packed fp16).
    // - The CAS intrinsic requires integer operands, so bitcast any floats to
    //   integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors, so use the element
      // type as a scalar instead.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    // Data operands come first: the value to store, then the compare value.
    // Each is bitcast to llvmBufferValType when that differs from its type.
    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset): the linearized element index plus any positive
    // `indexOffset`, then scaled to bytes.
    // NOTE(review): getLinearIndexI32 never returns a null Value, so the
    // `voffset ?` check below always takes the AddOp branch.
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
                                              extraOffsetConst)
                        : extraOffsetConst;
    }
    voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset: given in elements (0 when absent), scaled to bytes.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
                                           ArrayRef<NamedAttribute>());
    // Ops with a result have it bitcast back to the requested type when the
    // buffer value type differs. Result-less ops are simply erased.
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
                                              replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};
// TODO: AMDGPU backend already has all this bitpacking logic, we should move
// it to some common place.
/// Pack vmcnt, expcnt and lgkmcnt into the `s_waitcnt` immediate for
/// `chipset`. Each count is first clamped to the largest value its field can
/// hold. Returns failure for major versions above 11.
///
/// Field layout:
///   Vmcnt   = Waitcnt[3:0]        (pre-gfx9)
///   Vmcnt   = Waitcnt[15:14,3:0]  (gfx9,10)
///   Vmcnt   = Waitcnt[15:10]      (gfx11)
///   Expcnt  = Waitcnt[6:4]        (pre-gfx11)
///   Expcnt  = Waitcnt[2:0]        (gfx11)
///   Lgkmcnt = Waitcnt[11:8]       (pre-gfx10)
///   Lgkmcnt = Waitcnt[13:8]       (gfx10)
///   Lgkmcnt = Waitcnt[9:4]        (gfx11)
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
  const unsigned major = chipset.majorVersion;
  // Expcnt is a 3-bit field on every generation handled here.
  expcnt = std::min(7u, expcnt);

  if (major < 9) {
    vmcnt = std::min(15u, vmcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
  }

  if (major == 9 || major == 10) {
    // Both split the 6-bit vmcnt across [3:0] and [15:14]; only the width of
    // lgkmcnt differs (4 bits on gfx9, 6 bits on gfx10).
    vmcnt = std::min(63u, vmcnt);
    lgkmcnt = std::min(major == 9 ? 15u : 63u, lgkmcnt);
    const unsigned vmcntLow = vmcnt & 0xF;
    const unsigned vmcntHigh = (vmcnt >> 4) << 14;
    return vmcntLow | vmcntHigh | (expcnt << 4) | (lgkmcnt << 8);
  }

  if (major == 11) {
    vmcnt = std::min(63u, vmcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
  }

  return failure();
}
/// Lowers `amdgpu.memory_counter_wait`.
///
/// Before gfx12 there is a single packed `s_waitcnt`. The requested load and
/// store counts are summed into its vmcnt field and ds maps to lgkmcnt.
/// Counters that are not specified get a large value that encodeWaitcnt
/// clamps to the chipset's maximum. On gfx12+ each specified counter gets its
/// own wait op.
struct MemoryCounterWaitOpLowering
    : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 12) {
      constexpr unsigned kUnspecifiedCount = 1024;
      auto countOr = [](Attribute attr, unsigned fallback) -> unsigned {
        if (!attr)
          return fallback;
        return cast<IntegerAttr>(attr).getInt();
      };

      unsigned ds = countOr(adaptor.getDsAttr(), kUnspecifiedCount);
      unsigned exp = countOr(adaptor.getExpAttr(), kUnspecifiedCount);

      Attribute loadAttr = adaptor.getLoadAttr();
      Attribute storeAttr = adaptor.getStoreAttr();
      unsigned vmcnt = kUnspecifiedCount;
      if (loadAttr || storeAttr)
        vmcnt = countOr(loadAttr, 0) + countOr(storeAttr, 0);

      FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
      if (failed(waitcnt))
        return op.emitOpError("unsupported chipset");

      rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
      return success();
    }

    Location loc = op.getLoc();
    if (std::optional<int> dsCount = adaptor.getDs())
      ROCDL::WaitDscntOp::create(rewriter, loc, *dsCount);
    if (std::optional<int> loadCount = adaptor.getLoad())
      ROCDL::WaitLoadcntOp::create(rewriter, loc, *loadCount);
    if (std::optional<int> storeCount = adaptor.getStore())
      ROCDL::WaitStorecntOp::create(rewriter, loc, *storeCount);
    if (std::optional<int> expCount = adaptor.getExp())
      ROCDL::WaitExpcntOp::create(rewriter, loc, *expCount);
    rewriter.eraseOp(op);
    return success();
  }
};
/// Lowers `amdgpu.lds_barrier`: wait for outstanding LDS operations, then
/// perform a workgroup barrier.
///
/// - Before gfx90a and on gfx11, this is emitted as inline assembly
///   (`s_waitcnt lgkmcnt(0)` followed by `s_barrier`).
/// - On the remaining pre-gfx12 chips, it is an `s_waitcnt` whose immediate
///   zeroes only the lgkmcnt bits (all other bits are 1), followed by
///   `s_barrier`.
/// - On gfx12+, it is `s_wait_dscnt 0` followed by a barrier signal/wait pair.
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }

    if (chipset.majorVersion < 12) {
      // Each mask clears the lgkmcnt field of the packed s_waitcnt immediate
      // and leaves every other counter field at all-ones.
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        // Separate the version number from the message text.
        return op.emitOpError(
                   "don't know how to lower this for chipset major version ")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      ROCDL::WaitDscntOp::create(rewriter, loc, 0);
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};
/// Lowers `amdgpu.sched_barrier` to `rocdl.sched.barrier`, passing the op's
/// option mask through as an integer.
struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  // Not consulted by this lowering.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Named cast instead of a C-style cast; same conversion of the mask.
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(
        op, static_cast<uint32_t>(op.getOpts()));
    return success();
  }
};
} // namespace
/// Converts a MFMA vector operand from the MLIR AMDGPU dialect convention to
/// the ROCDL / LLVM AMDGPU intrinsic convention.
///
/// Non-vector inputs are returned unchanged. For a vector input with N
/// elements, the first matching rule applies:
/// 1. bf16 elements and `allowBf16` is false: bitcast to a vector of N i16.
///    Newer MFMAs accept bf16 operands directly (see IntrinsicsAMDGPU.td).
/// 2. i8 elements and N <= 8: bitcast to a single (N * 8)-bit integer.
/// 3. Any other integer elements of at most 8 bits (e.g. the f8f6f4 operand
///    types): bitcast to a vector of ceil(N * bitwidth / 32) i32 words.
/// Anything else is returned unchanged.
///
/// `input` has already been LLVM type converted, so 8-bit and smaller floats
/// appear here as their corresponding `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input,
                                      bool allowBf16 = true) {
  auto vectorType = dyn_cast<VectorType>(input.getType());
  if (!vectorType)
    return input;

  Type elemType = vectorType.getElementType();
  int64_t numElements = vectorType.getNumElements();

  if (!allowBf16 && elemType.isBF16())
    return LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);

  if (elemType.isInteger(8) && numElements <= 8)
    return LLVM::BitcastOp::create(
        rewriter, loc, rewriter.getIntegerType(numElements * 8), input);

  if (isa<IntegerType>(elemType) && vectorType.getElementTypeBitWidth() <= 8) {
    int64_t numWords = llvm::divideCeil(
        numElements * vectorType.getElementTypeBitWidth(), 32);
    return LLVM::BitcastOp::create(
        rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
        input);
  }

  return input;
}
/// Converts a scaled-MFMA scale operand (`scalesA` / `scalesB`) from the MLIR
/// AMDGPU dialect convention to the ROCDL / LLVM AMDGPU intrinsic convention,
/// which takes an i32.
///
/// Scalar integer inputs (e.g. i8) are zero-extended to i32. Any other input
/// (e.g. a vector<4xi8>) is bitcast to i32.
///
/// `input` has already been LLVM type converted, so 8-bit and smaller floats
/// appear here as their corresponding `iN` integers.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type i32 = rewriter.getI32Type();
  if (!isa<IntegerType>(input.getType()))
    return LLVM::BitcastOp::create(rewriter, loc, i32, input);
  return LLVM::ZExtOp::create(rewriter, loc, i32, input);
}
/// Push a WMMA input (A or B) operand onto `operands`, converted to the form
/// the ROCDL intrinsics expect.
///
/// - Non-vector inputs are pushed unchanged.
/// - bf16 vectors are bitcast to i16 vectors, because the bfloat intrinsics
///   were defined before the AMD backend supported bfloat.
/// - Vectors whose elements are wider than 8 bits are then pushed as-is.
/// - For elements of 8 bits or fewer:
///   - If the original MLIR element type is an integer, an i1 sign flag is
///     pushed first (1 = signed, 0 = unsigned). An explicitly signed (`si`)
///     or unsigned (`ui`) element type overrides `isUnsigned`.
///   - The data is then bitcast to an iN integer when it holds at most 32
///     bits, zero-extended to i32 if narrower (e.g. the 16-bit wave64 int4
///     case on gfx12+). Otherwise it is bitcast to a vector of i32 words (for
///     example, 16xi8 becomes 4xi32).
///   - 8-bit floats take the same packing path, since they are already `i8`
///     after type conversion, but get no sign flag.
///
/// \param isUnsigned  signedness used for signless integer inputs.
/// \param llvmInput   the operand after LLVM type conversion.
/// \param mlirInput   the original operand. It is needed because type
///                    conversion erases fp8-vs-int8 and signedness.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // Inspect the type before conversion. In LLVM, fp8 is converted to int8 and
  // all integers become signless, so neither the fp8/int8 distinction nor
  // si/ui signedness survives in `elemType`.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  Type mlirElemType = mlirInputType.getElementType();
  if (mlirElemType.isInteger()) {
    // An explicitly signed or unsigned element type overrides `isUnsigned`.
    bool localIsUnsigned = isUnsigned;
    if (mlirElemType.isUnsignedInteger())
      localIsUnsigned = true;
    else if (mlirElemType.isSignedInteger())
      localIsUnsigned = false;
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
  operands.push_back(castInput);
}
/// Push the WMMA accumulator (C) operand onto `operands`, followed by any
/// trailing flag the intrinsic takes.
///
/// bf16 accumulators are bitcast to i16 vectors.
///
/// For 16-bit outputs (f16, bf16, i16), an i1 `subwordOffset` flag follows.
/// These f16 -> f16 / bf16 -> bf16 style intrinsics use the same number of
/// VGPRs as 32-bit results, so the flag selects where the result is stored:
/// 1 selects the lower 16 bits of each VGPR, otherwise the upper 16 bits are
/// used. The subwordOffset must not be set for gfx12, where these instructions
/// return fewer registers instead.
///
/// For i32 outputs, an i1 `clamp` flag follows instead. Other element types
/// get no trailing flag.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  // The accumulator is expected to be a vector. `cast` asserts that, instead
  // of dereferencing a null `dyn_cast` result.
  auto vectorType = cast<VectorType>(output.getType());
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16))
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset != 0));
  else if (elemType.isInteger(32))
    operands.push_back(createI1Constant(rewriter, loc, clamp));
}
/// Return true if `type` is the E5M2 8-bit float variant that the `_bf8`
/// instructions accept on `chipset`. That is the FNUZ variant on gfx942, or
/// the OCP variant on chipsets with OCP fp8 support.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  if (isa<Float8E5M2FNUZType>(type))
    return chipset == kGfx942;
  if (isa<Float8E5M2Type>(type))
    return hasOcpFp8(chipset);
  return false;
}
/// Return true if `type` is the E4M3 8-bit float variant that the `_fp8`
/// instructions accept on `chipset`. That is E4M3FNUZ on gfx942, or the OCP
/// E4M3FN type on chipsets with OCP fp8 support.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  if (isa<Float8E4M3FNUZType>(type))
    return chipset == kGfx942;
  if (isa<Float8E4M3FNType>(type))
    return hasOcpFp8(chipset);
  return false;
}
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
///
/// The lookup is keyed on the source A element type, the accumulator (C)
/// element type, and the (m, n, k, blocks) shape. For 8-bit float sources,
/// the element type of source B is also consulted. Returns std::nullopt when
/// no intrinsic matches.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  // f32 x f32 -> f32. With reducePrecision on gfx942+, the xf32 variants are
  // tried first.
  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  // f16 x f16 -> f32. The double-k shapes require gfx950+.
  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  // bf16 x bf16 -> f32. Checks run newest-first: the gfx950 shapes, then the
  // `_1k` variants (gfx90a+), then the original bf16 variants.
  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  // i8 x i8 -> i32. Larger-k shapes need gfx942+ or gfx950+ as noted.
  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  // f64 x f64 -> f64, gfx90a+ only.
  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  // 8-bit float sources -> f32. The intrinsic name encodes both the A and the
  // B operand formats (bf8 = E5M2, fp8 = E4M3), so both are checked.
  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
/// Return the operand type-select code that the scaled ("f8f6f4") MFMA
/// intrinsics expect for `mlirElemType`:
///   Float8E4M3FN -> 0, Float8E5M2 -> 1, Float6E2M3FN -> 2,
///   Float6E3M2FN -> 3, Float4E2M1FN -> 4.
/// Any other element type has no code and yields std::nullopt.
static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  if (isa<Float8E4M3FNType>(mlirElemType))
    return 0u;
  if (isa<Float8E5M2Type>(mlirElemType))
    return 1u;
  if (isa<Float6E2M3FNType>(mlirElemType))
    return 2u;
  if (isa<Float6E3M2FNType>(mlirElemType))
    return 3u;
  if (isa<Float4E2M1FNType>(mlirElemType))
    return 4u;
  return std::nullopt;
}
/// Look up the scaled MFMA intrinsic for the given operand element types
/// (`aType`, `bType`), accumulator element type (`destType`), problem shape
/// M x N x K and block count B on `chipset`.
///
/// On success, returns the intrinsic's operation name together with the
/// type-select codes for A and B. These codes must be passed as intrinsic
/// operands. Only gfx950+ with an f32 accumulator and a single block is
/// supported; the valid shapes are 32x32x64 and 16x16x128.
///
/// This is also used to lower some *unscaled* MFMAs. The compiler models
/// those as a scaled MFMA whose scales are zero.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  // The f8f6f4 family only exists from gfx950 on and always accumulates in f32.
  Type accElem = getElementTypeOrSelf(destType);
  if (chipset < kGfx950 || !isa<Float32Type>(accElem))
    return std::nullopt;

  std::optional<uint32_t> aCode =
      mfmaTypeSelectCode(getElementTypeOrSelf(aType));
  std::optional<uint32_t> bCode =
      mfmaTypeSelectCode(getElementTypeOrSelf(bType));
  if (!aCode.has_value() || !bCode.has_value())
    return std::nullopt;

  if (b != 1)
    return std::nullopt;

  StringRef intrinsic;
  if (m == 32 && n == 32 && k == 64)
    intrinsic = ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName();
  else if (m == 16 && n == 16 && k == 128)
    intrinsic = ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName();
  else
    return std::nullopt;

  return std::tuple{intrinsic, *aCode, *bCode};
}
/// Convenience overload: query the scaled-MFMA table using the operand
/// types, M/N/K sizes and block count carried by an MFMAOp.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  Type lhsType = mfma.getSourceA().getType();
  Type rhsType = mfma.getSourceB().getType();
  Type accType = mfma.getDestC().getType();
  return mfmaOpToScaledIntrinsic(lhsType, rhsType, accType, mfma.getM(),
                                 mfma.getN(), mfma.getK(), mfma.getBlocks(),
                                 chipset);
}
/// Convenience overload for ScaledMFMAOp. That op carries no block count,
/// so the lookup is always made with a single block.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  constexpr uint32_t kSingleBlock = 1u;
  Type lhsType = smfma.getSourceA().getType();
  Type rhsType = smfma.getSourceB().getType();
  Type accType = smfma.getDestC().getType();
  return mfmaOpToScaledIntrinsic(lhsType, rhsType, accType, smfma.getM(),
                                 smfma.getN(), smfma.getK(), kSingleBlock,
                                 chipset);
}
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
///
/// Operand A, operand B and the accumulator C must be vectors, because their
/// element types are read unconditionally below. `cast` is therefore used
/// instead of a `dyn_cast` whose null result was never checked.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
  Type elemSourceType = sourceVectorType.getElementType();
  Type elemBSourceType = sourceBVectorType.getElementType();
  Type elemDestType = destVectorType.getElementType();

  // Forms shared by gfx11 and gfx12.
  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();

  if (chipset.majorVersion == 11) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
  }

  if (chipset.majorVersion >= 12) {
    // fp8/bf8 variants are selected by the element types of both A and B.
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
      bool isWave64 = destVectorType.getNumElements() == 4;
      // This is the ambiguous case. With wave64, 8 inputs mean the 16x16x32
      // version is wanted; with wave32, 8 inputs mean the short form.
      bool has8Inputs = sourceVectorType.getNumElements() == 8;
      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    }
  }
  return std::nullopt;
}
namespace {
/// Lowers MFMAOp to the matching `rocdl` MFMA intrinsic.
///
/// On gfx950+, a plain intrinsic may not exist for a given shape and type
/// combination. In that case the scaled (f8f6f4) intrinsic is used, with its
/// A/B type codes filled in and all scale operands set to zero.
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  /// Target chipset; gates which intrinsics and modifiers are legal.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    // Like WMMAOpLowering, refuse to build an intrinsic with a null result
    // type when the destination type cannot be converted.
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    // bf16 result vectors are produced as i16 vectors by the intrinsic and
    // bitcast back to `outType` at the end.
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");

    // The A/B/C negation flags are OR-ed into bits 0/1/2 of the blgp
    // immediate; this is only available from gfx942 on.
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }

    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    // The plain intrinsic wins; the scaled one is only a fallback.
    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    // In the scaled form, the operand slots that otherwise carry
    // cbsz/abid/blgp hold type codes and scales instead, so those modifiers
    // must be at their defaults.
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }
    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;

    // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
    // allow bf16 as the input. For reference, check IntrinsicsAMDGPU.td.
    bool allowBf16 = [&]() {
      if (chipset < kGfx950)
        return false;
      if (isScaled)
        return true;
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");
    }();

    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands({convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceA(), allowBf16),
                           convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceB(), allowBf16),
                           adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};
/// Lowers ScaledMFMAOp to the gfx950 scaled (f8f6f4) MFMA intrinsics. After
/// A, B and C, the intrinsic receives the A/B type codes followed by the
/// scale index and scale value for each of A and B.
struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  /// Target chipset; scaled MFMAs require gfx950.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
    if (!intrinsicOutType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    // The condition requires gfx950; the diagnostic now names the same
    // minimum chipset the check enforces.
    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};
/// Lowers WMMAOp (gfx11/gfx12 only) to the `rocdl` WMMA intrinsic picked by
/// wmmaOpToIntrinsic. bf16 results are carried through the intrinsic as i16
/// vectors and bitcast back afterwards.
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  /// Target chipset; selects the legal intrinsic set.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto resultVecTy =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!resultVecTy)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    bool isGfx11Or12 =
        chipset.majorVersion == 11 || chipset.majorVersion == 12;
    if (!isGfx11Or12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The intrinsics model bf16 vectors as i16 vectors.
    const bool needsBf16Cast = resultVecTy.getElementType().isBF16();
    VectorType intrinsicResultTy = resultVecTy;
    if (needsBf16Cast)
      intrinsicResultTy = resultVecTy.clone(rewriter.getI16Type());

    std::optional<StringRef> intrinsicName = wmmaOpToIntrinsic(op, chipset);
    if (!intrinsicName)
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");
    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    // Operand order: A, B, then the accumulator with its modifiers.
    SmallVector<Value, 4> intrinsicArgs;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), intrinsicArgs);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), intrinsicArgs);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), intrinsicArgs);

    OperationState state(loc, *intrinsicName);
    state.addTypes(intrinsicResultTy);
    state.addOperands(intrinsicArgs);
    Value result = rewriter.create(state)->getResult(0);
    if (needsBf16Cast)
      result = LLVM::BitcastOp::create(rewriter, loc, resultVecTy, result);
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Lowers TransposeLoadOp to the gfx950 `ds_read_tr*` intrinsics. The
/// intrinsic is chosen by the result element bit width (4, 6, 8 or 16).
/// Memrefs with sub-byte element types are rejected.
struct TransposeLoadOpLowering
    : public ConvertOpToLLVMPattern<TransposeLoadOp> {
  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}

  /// Target chipset; must be exactly gfx950.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx950)
      return op.emitOpError("Non-gfx950 chipset not supported");

    Location loc = op.getLoc();
    auto memrefTy = cast<MemRefType>(op.getSrc().getType());

    // Sub-byte memref elements are not stored contiguously, so they cannot
    // be addressed here; such memrefs should be emulated instead.
    size_t memElemBits = memrefTy.getElementType().getIntOrFloatBitWidth();
    if (memElemBits < 8)
      return op.emitOpError("Expect source memref to have at least 8 bits "
                            "element size, got ")
             << memElemBits;

    auto vecTy = cast<VectorType>(op.getResult().getType());
    Value basePtr = getStridedElementPtr(rewriter, loc, memrefTy,
                                         adaptor.getSrc(),
                                         adaptor.getSrcIndices());
    size_t laneCount = vecTy.getNumElements();
    size_t laneBits = vecTy.getElementType().getIntOrFloatBitWidth();

    // For element widths below 16 bits, the intrinsics return a vector of i32
    // words, which is bitcast to the converted result type.
    Type wordVecTy = VectorType::get((laneCount * laneBits) / 32,
                                     rewriter.getIntegerType(32));
    Type llvmVecTy = typeConverter->convertType(vecTy);

    Value words;
    switch (laneBits) {
    case 4:
      assert(laneCount == 16);
      words = ROCDL::ds_read_tr4_b64::create(rewriter, loc, wordVecTy, basePtr);
      break;
    case 6:
      assert(laneCount == 16);
      words = ROCDL::ds_read_tr6_b96::create(rewriter, loc, wordVecTy, basePtr);
      break;
    case 8:
      assert(laneCount == 8);
      words = ROCDL::ds_read_tr8_b64::create(rewriter, loc, wordVecTy, basePtr);
      break;
    case 16:
      // 16-bit elements are returned directly in the converted result type.
      assert(laneCount == 4);
      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmVecTy,
                                                            basePtr);
      return success();
    default:
      return op.emitOpError("Unsupported element size for transpose load");
    }
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmVecTy, words);
    return success();
  }
};
/// Lowers GatherToLDSOp to ROCDL::LoadToLDSOp. Each lane transfers a single
/// `transferType` value from the source memref into the destination memref.
/// Only gfx9 and gfx10 are accepted, and 12- and 16-byte transfers
/// additionally require gfx950.
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  /// Target chipset.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const bool isGfx9Or10 =
        chipset.majorVersion >= 9 && chipset.majorVersion <= 10;
    if (!isGfx9Or10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: only one element is transferred per thread; issuing several
    // `global_load_lds` instructions would allow wider per-thread transfers.
    Type transferType = op.getTransferType();
    int loadWidth = 0; // bytes per lane
    if (auto vecTy = dyn_cast<VectorType>(transferType))
      loadWidth = (vecTy.getNumElements() * vecTy.getElementTypeBitWidth()) / 8;
    else
      loadWidth = transferType.getIntOrFloatBitWidth() / 8;

    // Only 1, 2, 4, 12 and 16 byte loads exist; the two wide ones are
    // gfx950-only.
    switch (loadWidth) {
    case 1:
    case 2:
    case 4:
      break;
    case 12:
    case 16:
      if (chipset != kGfx950)
        return op.emitOpError(
            "Gather to LDS instructions with 12-byte and "
            "16-byte load widths are only supported on gfx950");
      break;
    default:
      return op.emitOpError("chipset unsupported element size");
    }

    Value srcPtr = getStridedElementPtr(rewriter, loc, srcMemRefType,
                                        adaptor.getSrc(),
                                        adaptor.getSrcIndices());
    Value dstPtr = getStridedElementPtr(rewriter, loc, dstMemRefType,
                                        adaptor.getDst(),
                                        adaptor.getDstIndices());
    IntegerAttr zeroAttr = rewriter.getI32IntegerAttr(0);
    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        /*offset=*/zeroAttr, /*aux=*/zeroAttr, ArrayAttr{}, ArrayAttr{},
        ArrayAttr{});
    return success();
  }
};
namespace {
/// Lowers ExtPackedFp8Op to the ROCDL fp8/bf8 -> f32 conversion ops.
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  /// Target chipset.
  Chipset chipset;

  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers PackedTrunc2xFp8Op to the ROCDL packed f32 -> fp8/bf8 ops.
struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  /// Target chipset.
  Chipset chipset;

  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers PackedStochRoundFp8Op to the ROCDL stochastic-rounding
/// f32 -> fp8/bf8 ops.
struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  /// Target chipset.
  Chipset chipset;

  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers ScaledExtPackedOp to the ROCDL scaled extension ops.
struct ScaledExtPackedOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
  /// Target chipset.
  Chipset chipset;

  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers PackedScaledTruncOp to the ROCDL scaled truncation ops.
struct PackedScaledTruncOpLowering final
    : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
  /// Target chipset.
  Chipset chipset;

  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace
/// Lowers ExtPackedFp8Op on gfx942 or OCP-fp8 chipsets.
///
/// A scalar or short-vector 8-bit-float source is padded to a v4i8 (padding
/// lanes are undef) and bitcast to i32. That i32 is fed to a ROCDL
/// conversion op together with `index`:
///   - a packed (CvtPkF32*) op when the result is a vector;
///   - a scalar (CvtF32*) op otherwise.
/// The bf8 or fp8 variant is chosen from the source element type.
///
/// The match fails, before any IR is emitted, if the source element type is
/// neither the bf8 nor the fp8 format the chipset expects. This avoids
/// returning success() without having replaced the op.
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  // Classify the source format first so that an unsupported format is
  // rejected cleanly instead of falling through both branches below.
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  bool isBf8 = typeIsExpectedBf8ForChipset(chipset, sourceElemType);
  bool isFp8 = !isBf8 && typeIsExpectedFp8ForChipset(chipset, sourceElemType);
  if (!isBf8 && !isFp8)
    return rewriter.notifyMatchFailure(
        loc, "source element type is not the 8-bit float format expected by "
             "the target chipset");

  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  // Extend to a v4i8 so the source can be bitcast to the i32 operand.
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
    if (!sourceVecType) {
      longVec = LLVM::InsertElementOp::create(
          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
        longVec =
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);

  if (resultVecType) {
    if (isBf8)
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        op.getIndex());
    else
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        op.getIndex());
  } else {
    if (isBf8)
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      op.getIndex());
    else
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      op.getIndex());
  }
  return success();
}
/// Lowers ScaledExtPackedOp on gfx950.
///
/// The source vector (Float8E5M2 / Float8E4M3FN, or Float4E2M1FN elements)
/// is zero-padded to the packed vector type, which is v4i8 or v8i4 before
/// type conversion. The padded value is bitcast to i32 and converted with
/// `scale` by the ROCDL::CvtScaleF32Pk*Op that matches the
/// (source, destination) element-type pair; `index` is forwarded unchanged.
/// Unsupported destination types fail the match.
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
    ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();

  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();
  VectorType destVecType = cast<VectorType>(op.getResult().getType());
  Type destElemType = destVecType.getElementType();

  // 32 bits worth of packed source elements.
  VectorType packedVecType;
  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
    VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
  } else if (isa<Float4E2M1FNType>(sourceElemType)) {
    VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
  } else {
    llvm_unreachable("invalid element type for scaled ext");
  }

  // Extend to the packed vector type. `sourceVecType` comes from a cast<>, so
  // the source is always a vector here; the former scalar-source branch was
  // unreachable and has been removed.
  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
    Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
    for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
      Value idx = createI32Constant(rewriter, loc, i);
      Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
      longVec =
          LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);

  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else
    return failure();

  return success();
}
/// Lowers PackedScaledTruncOp on gfx950 to a ROCDL::CvtScaleF32Pk*Op chosen
/// by the (source, result) element-type pair.
///
/// - The packed intermediate is i32 for Float4E2M1FN results and v2i16
///   otherwise. `existing` is bitcast to that type, or replaced by zero when
///   absent.
/// - A one-element source is widened to two elements, with the extra element
///   set to zero.
/// - f32 sources are passed to the ROCDL op as two scalars; f16/bf16 sources
///   are passed as a vector.
/// - The ROCDL result is bitcast to the converted result type.
///
/// Unmatched type combinations fail the match.
/// NOTE(review): `index` presumably selects which part of `existing` the
/// converted values are written to; confirm against the ROCDL op docs.
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v2i16 = getTypeConverter()->convertType(
      VectorType::get(2, rewriter.getI16Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  // fp4 results are packed into an i32; fp8/bf8 results into a v2i16.
  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();
  Value existing = adaptor.getExisting();
  // Reinterpret the value being merged into, or start from all zeros.
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
  else
    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);

  // Widen a single-element source to two elements; the second is zero.
  if (sourceVecType.getNumElements() < 2) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    VectorType v2 = VectorType::get(2, sourceElemType);
    source = LLVM::ZeroOp::create(rewriter, loc, v2);
    source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
  }

  // The f32 variants take their two inputs as separate scalar operands.
  Value sourceA, sourceB;
  if (sourceElemType.isF32()) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value c1 = createI32Constant(rewriter, loc, 1);
    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
  }

  Value result;
  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else
    return failure();

  // Reinterpret the packed integer result as the converted result type.
  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
/// Lowers PackedTrunc2xFp8Op on gfx942 or OCP-fp8 chipsets to
/// ROCDL::CvtPkBf8F32Op or ROCDL::CvtPkFp8F32Op, chosen by the result element
/// type.
///
/// - A missing `sourceB` is replaced by undef.
/// - A missing `existing` is replaced by an undef i32; otherwise `existing`
///   is bitcast to i32.
/// - `wordIndex` is forwarded unchanged, and the i32 result is bitcast to the
///   converted result type.
///
/// The match fails, before any IR is emitted, if the result element type is
/// neither the bf8 nor the fp8 format the chipset expects. Previously that
/// case built a bitcast of a null Value.
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  // Select the intrinsic before creating any ops so that an unsupported
  // format fails the match cleanly.
  bool isBf8 = typeIsExpectedBf8ForChipset(chipset, resultElemType);
  bool isFp8 = !isBf8 && typeIsExpectedFp8ForChipset(chipset, resultElemType);
  if (!isBf8 && !isFp8)
    return rewriter.notifyMatchFailure(
        loc, "result element type is not the 8-bit float format expected by "
             "the target chipset");

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (isBf8)
    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());
  else
    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
/// Lowers PackedStochRoundFp8Op on gfx942 or OCP-fp8 chipsets to
/// ROCDL::CvtSrBf8F32Op or ROCDL::CvtSrFp8F32Op, chosen by the result element
/// type.
///
/// - `source` and the stochastic-rounding parameter are passed through.
/// - A missing `existing` is replaced by an undef i32; otherwise `existing`
///   is bitcast to i32.
/// - `storeIndex` is forwarded unchanged, and the i32 result is bitcast to
///   the converted result type.
///
/// The match fails, before any IR is emitted, if the result element type is
/// neither the bf8 nor the fp8 format the chipset expects. Previously that
/// case built a bitcast of a null Value.
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  // Select the intrinsic before creating any ops so that an unsupported
  // format fails the match cleanly.
  bool isBf8 = typeIsExpectedBf8ForChipset(chipset, resultElemType);
  bool isFp8 = !isBf8 && typeIsExpectedFp8ForChipset(chipset, resultElemType);
  if (!isBf8 && !isFp8)
    return rewriter.notifyMatchFailure(
        loc, "result element type is not the 8-bit float format expected by "
             "the target chipset");

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (isBf8)
    result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());
  else
    result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  /// Computes the `dpp_ctrl` immediate for permutation `kind`.
  ///
  /// `permArg` is the op's permutation argument (null when absent): an
  /// ArrayAttr of lane selectors for quad_perm, or an IntegerAttr shift amount
  /// for the row_shl/row_shr/row_ror kinds. Returns failure when the argument
  /// a kind requires is missing or has the wrong attribute type, instead of
  /// asserting or dereferencing an empty optional.
  static FailureOr<uint32_t> computeDppCtrl(DPPPerm kind, Attribute permArg) {
    // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    // Row shift/rotate kinds encode their amount as an offset from a base.
    auto shiftedFrom = [&](uint32_t base) -> FailureOr<uint32_t> {
      auto intAttr = dyn_cast_if_present<IntegerAttr>(permArg);
      if (!intAttr)
        return failure();
      uint32_t ctrl = static_cast<uint32_t>(intAttr.getInt()) + base;
      return ctrl;
    };

    uint32_t ctrl = 0;
    switch (kind) {
    case DPPPerm::quad_perm: {
      auto quadPermAttr = dyn_cast_if_present<ArrayAttr>(permArg);
      // One 2-bit lane selector per lane of a quad; extra entries would spill
      // into the neighbouring control fields.
      if (!quadPermAttr || quadPermAttr.size() != 4)
        return failure();
      unsigned lane = 0;
      for (IntegerAttr elem : quadPermAttr.getAsRange<IntegerAttr>()) {
        uint32_t selector = static_cast<uint32_t>(elem.getInt());
        ctrl |= selector << (lane * 2);
        ++lane;
      }
      break;
    }
    case DPPPerm::row_shl:
      return shiftedFrom(DppCtrl::ROW_SHL0);
    case DPPPerm::row_shr:
      return shiftedFrom(DppCtrl::ROW_SHR0);
    case DPPPerm::row_ror:
      return shiftedFrom(DppCtrl::ROW_ROR0);
    case DPPPerm::wave_shl:
      ctrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      ctrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      ctrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      ctrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      ctrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      ctrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      ctrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      ctrl = DppCtrl::BCAST31;
      break;
    }
    return ctrl;
  }

  /// Rewrites amdgpu.dpp into rocdl.update.dpp.
  ///
  /// Sources narrower than 32 bits are widened to i32 by inserting them into
  /// lane 0 of an undef vector that is then bitcast to i32. After the DPP op,
  /// the result is truncated and, for floats, bitcast back to the source type.
  /// 32/64-bit sources are passed through as f32/f64 or i32/i64.
  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the source operand to the corresponding LLVM type
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();

    // Validate everything before creating IR so a rejected op leaves the IR
    // untouched.
    if (!srcType.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          loc, "DPP source must be a scalar integer or float");
    unsigned srcWidth = srcType.getIntOrFloatBitWidth();
    // Narrow values are packed through vector<(32/width) x iN>, so the width
    // must evenly divide 32; wider values map directly to 32/64-bit types.
    bool packsIntoI32 = srcWidth != 0 && srcWidth < 32 && 32 % srcWidth == 0;
    if (!packsIntoI32 && srcWidth != 32 && srcWidth != 64)
      return rewriter.notifyMatchFailure(loc,
                                         "unsupported DPP source bit width");

    auto permArgument = DppOp.getPermArgument();
    FailureOr<uint32_t> dppCtrl = computeDppCtrl(
        DppOp.getKind(), permArgument ? *permArgument : Attribute());
    if (failed(dppCtrl))
      return rewriter.notifyMatchFailure(
          loc, "missing or malformed DPP permutation argument");

    Type llvmType;
    if (srcWidth < 32)
      llvmType = rewriter.getI32Type();
    else if (isa<FloatType>(srcType))
      llvmType =
          srcWidth == 32 ? rewriter.getF32Type() : rewriter.getF64Type();
    else
      llvmType =
          srcWidth == 32 ? rewriter.getI32Type() : rewriter.getI64Type();

    auto llvmSrcIntType =
        typeConverter->convertType(rewriter.getIntegerType(srcWidth));

    // If the source type is narrower than 32 bits, pack it into an i32.
    auto convertOperand = [&](Value operand, Type operandType) {
      unsigned width = operandType.getIntOrFloatBitWidth();
      if (width < 32) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(
            mlir::VectorType::get(32 / width, llvmSrcIntType));
        Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
        operand =
            LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
                                          createI32Constant(rewriter, loc, 0));
        operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // Check for row_mask, bank_mask, bound_ctrl if they exist and create
    // constants
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // create a ROCDL_DPPMovOp instruction with the appropriate attributes
    auto dppMovOp =
        ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, *dppCtrl,
                                   rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcWidth < 32) {
      result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
      }
    }

    // We are replacing the AMDGPU_DPPOp instruction with the new
    // ROCDL_DPPMovOp instruction
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};
struct AMDGPUSwizzleBitModeLowering
    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  /// Lowers amdgpu.swizzle_bitmode. The source is split into i32 pieces, each
  /// piece gets a rocdl ds_swizzle with a bit-mode offset built from the op's
  /// and/or/xor masks, and the swizzled pieces are recombined into the
  /// source's type.
  LogicalResult
  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value source = adaptor.getSrc();
    Type sourceTy = source.getType();

    SmallVector<Value> pieces =
        LLVM::decomposeValue(rewriter, loc, source, rewriter.getI32Type());

    // Offset layout: and_mask starts at bit 0, or_mask at bit 5 and xor_mask
    // at bit 10. Bit 15 is 0 for the BitMode swizzle.
    // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
    unsigned offset = op.getAndMask();
    offset |= static_cast<unsigned>(op.getOrMask()) << 5;
    offset |= static_cast<unsigned>(op.getXorMask()) << 10;
    Value offsetValue = createI32Constant(rewriter, loc, offset);

    SmallVector<Value> shuffled;
    shuffled.reserve(pieces.size());
    for (Value piece : pieces)
      shuffled.push_back(ROCDL::DsSwizzleOp::create(
          rewriter, loc, piece.getType(), piece, offsetValue));

    rewriter.replaceOp(op,
                       LLVM::composeValue(rewriter, loc, shuffled, sourceTy));
    return success();
  }
};
struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
  using Base::Base;

  /// Parses the `chipset` option, then applies a partial conversion in which
  /// the AMDGPU dialect is illegal and the LLVM and ROCDL dialects are legal.
  /// An unparsable chipset name emits an error and fails the pass.
  void runOnOperation() override {
    MLIRContext *context = &getContext();

    FailureOr<Chipset> parsedChipset = Chipset::parse(chipset);
    if (failed(parsedChipset)) {
      emitError(UnknownLoc::get(context), "Invalid chipset name: " + chipset);
      signalPassFailure();
      return;
    }

    RewritePatternSet patternSet(context);
    LLVMTypeConverter typeConverter(context);
    populateAMDGPUToROCDLConversionPatterns(typeConverter, patternSet,
                                            *parsedChipset);

    LLVMConversionTarget target(*context);
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect,
                           ::mlir::ROCDL::ROCDLDialect>();

    LogicalResult converted =
        applyPartialConversion(getOperation(), target, std::move(patternSet));
    if (failed(converted))
      signalPassFailure();
  }
};
} // namespace
/// Registers a type-attribute conversion that maps amdgpu address spaces on
/// memref types to numeric i64 address spaces: FatRawBuffer -> 7,
/// BufferRsrc -> 8, FatStructuredBuffer -> 9. Any other value aborts the
/// conversion.
void mlir::populateAMDGPUMemorySpaceAttributeConversions(
    TypeConverter &typeConverter) {
  auto toNumericAddressSpace =
      [](BaseMemRefType type, amdgpu::AddressSpaceAttr addrSpace)
      -> TypeConverter::AttributeConversionResult {
    Type i64Ty = IntegerType::get(addrSpace.getContext(), 64);
    switch (addrSpace.getValue()) {
    case amdgpu::AddressSpace::FatRawBuffer:
      return IntegerAttr::get(i64Ty, 7);
    case amdgpu::AddressSpace::BufferRsrc:
      return IntegerAttr::get(i64Ty, 8);
    case amdgpu::AddressSpace::FatStructuredBuffer:
      return IntegerAttr::get(i64Ty, 9);
    }
    return TypeConverter::AttributeConversionResult::abort();
  };
  typeConverter.addTypeAttributeConversion(std::move(toNumericAddressSpace));
}
/// Registers the AMDGPU memory-space attribute conversions on `converter` and
/// adds every AMDGPU -> ROCDL lowering pattern to `patterns`. Chipset-aware
/// patterns receive `chipset`; the bit-mode swizzle lowering does not need it.
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns,
                                                   Chipset chipset) {
  populateAMDGPUMemorySpaceAttributeConversions(converter);

  // Buffer cast plus raw buffer loads, stores and atomics.
  patterns.add<FatRawBufferCastLowering,
               RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
               RawBufferOpLowering<RawBufferStoreOp,
                                   ROCDL::RawPtrBufferStoreOp>,
               RawBufferOpLowering<RawBufferAtomicFaddOp,
                                   ROCDL::RawPtrBufferAtomicFaddOp>,
               RawBufferOpLowering<RawBufferAtomicFmaxOp,
                                   ROCDL::RawPtrBufferAtomicFmaxOp>,
               RawBufferOpLowering<RawBufferAtomicSmaxOp,
                                   ROCDL::RawPtrBufferAtomicSmaxOp>,
               RawBufferOpLowering<RawBufferAtomicUminOp,
                                   ROCDL::RawPtrBufferAtomicUminOp>,
               RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                                   ROCDL::RawPtrBufferAtomicCmpSwap>>(
      converter, chipset);

  // Remaining chipset-aware lowerings (DPP, waits/barriers, matrix ops,
  // fp8/scaled conversions, LDS gather and transpose loads).
  patterns.add<AMDGPUDPPLowering, MemoryCounterWaitOpLowering,
               LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering,
               ScaledMFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
               ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
               PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
               GatherToLDSOpLowering, TransposeLoadOpLowering>(converter,
                                                               chipset);

  // Chipset-independent lowerings.
  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}