Krzysztof Drewniak 2ebd633f14 [mlir][AMDGPU] Add packed 8-bit float conversion ops and lowering
Define operations that wrap the gfx940's new operations for converting
between f32 and registers containing packed sets of four 8-bit floats.

Define rocdl operations for the intrinsics and an AMDGPU dialect
wrapper around them (to account for the fact that MLIR distinguishes
the two float formats at the type level but that the LLVM IR does
not).

Define an ArithToAMDGPU pass, meant to run before conversion to LLVM,
that replaces relevant calls to arith.extf and arith.truncf with the
packed operations in the AMDGPU dialect. Note that the conversion
currently only handles scalars and vectors of rank <= 1, as we do not
have a use case for multi-dimensional vector support right now.

Reviewed By: jsjodin

Differential Revision: https://reviews.llvm.org/D152457
2023-09-28 14:44:16 +00:00

858 lines
36 KiB
C++

//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include <optional>
namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::amdgpu;
/// Creates an `LLVM::ConstantOp` of type i32 holding `value` at `loc` and
/// returns its result. The op is always materialized (no folding attempt).
static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32Ty = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32Ty, value);
}
/// Creates (or folds to an existing value) an `LLVM::ConstantOp` of type i1
/// holding `value` at `loc` and returns the resulting value.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type i1Ty = rewriter.getI1Type();
  return rewriter.createOrFold<LLVM::ConstantOp>(loc, i1Ty, value);
}
namespace {
/// Lowering pattern shared by all raw buffer operations (loads, stores and
/// atomics). `GpuOp` is the AMDGPU dialect operation being lowered and
/// `Intrinsic` is the ROCDL operation that replaces it.
///
/// The pattern:
///   1. picks the value type the intrinsic operates on (bitcasting sub-word
///      vectors, bfloat values and compare-and-swap floats to integers),
///   2. builds a buffer resource from the memref descriptor,
///   3. computes the byte offset (`voffset`) from the indices and strides, and
///   4. emits the intrinsic, bitcasting the result back when needed.
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  /// Target chipset; gates availability and selects descriptor flag bits.
  Chipset chipset;
  /// Widest load/store (in bits) this lowering will emit.
  static constexpr uint32_t maxVectorOpWidth = 128;

  /// Converts the integer `val` (typically a lowered `index` read from the
  /// memref descriptor) to i32 so that it can be combined with the i32
  /// constants and indices used by the buffer intrinsics. Wider values are
  /// truncated, narrower values are zero-extended, and i32 is returned as-is.
  static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                    Location loc, Value val) {
    Type i32 = rewriter.getI32Type();
    auto valTy = cast<IntegerType>(val.getType());
    if (val.getType() == i32)
      return val;
    if (valTy.getWidth() > 32)
      return rewriter.create<LLVM::TruncOp>(loc, i32, val);
    return rewriter.create<LLVM::ZExtOp>(loc, i32, val);
  }

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    // Operand group 0 is the stored value for stores/atomics, but it is the
    // memref itself for loads, which have no write component.
    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices, trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());

    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
    // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
    // so bitcast any floats to integers. On top of all this, cast bfloat
    // (vectors) to i16 since the backend doesn't currently support bfloat on
    // these operations.
    Type llvmBufferValType = llvmWantedDataType;
    if (wantedDataType.isBF16())
      llvmBufferValType = rewriter.getI16Type();
    if (auto wantedVecType = dyn_cast<VectorType>(wantedDataType))
      if (wantedVecType.getElementType().isBF16())
        llvmBufferValType = wantedVecType.clone(rewriter.getI16Type());
    if (atomicCmpData) {
      if (isa<VectorType>(wantedDataType))
        return gpuOp.emitOpError("vector compare-and-swap does not exist");
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t elemBits = dataVector.getElementTypeBitWidth();
      uint32_t totalBits = elemBits * dataVector.getNumElements();
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    // Intrinsic arguments, in order: [store data], [compare data], resource,
    // voffset, sgpr offset, auxiliary flags.
    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
    // The stride value is always 0 for raw buffers. This also disables
    // swizzling.
    Value stride = rewriter.createOrFold<LLVM::ConstantOp>(
        loc, llvmI16, rewriter.getI16IntegerAttr(0));

    // Number of bytes addressable through the descriptor.
    Value numRecords;
    if (memrefType.hasStaticShape()) {
      numRecords = createI32Constant(
          rewriter, loc,
          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
    } else {
      // Take the largest `size * stride` over all dimensions (in elements).
      // These are integers, so an unsigned integer max is used rather than a
      // floating-point maximum, and the arithmetic is performed in the
      // descriptor's own integer type before narrowing to i32.
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value dimSize = memrefDescriptor.size(rewriter, loc, i);
        Value dimStride = memrefDescriptor.stride(rewriter, loc, i);
        Value maxThisDim =
            rewriter.create<LLVM::MulOp>(loc, dimSize, dimStride);
        maxIndex = maxIndex ? Value(rewriter.create<LLVM::UMaxOp>(
                                  loc, maxIndex, maxThisDim))
                            : maxThisDim;
      }
      Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
      numRecords =
          rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
    }

    // Flag word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: data format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means  "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    //  none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t flags = (7 << 12) | (4 << 15);
    if (chipset.majorVersion >= 10) {
      flags |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      flags |= (oob << 28);
    }
    Value flagsConst = createI32Constant(rewriter, loc, flags);
    Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
    Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
        loc, rsrcType, ptr, stride, numRecords, flagsConst);
    args.push_back(resource);

    // Indexing (voffset): sum of index * stride * element byte width.
    Value voffset = createI32Constant(rewriter, loc, 0);
    for (auto pair : llvm::enumerate(adaptor.getIndices())) {
      size_t i = pair.index();
      Value index = pair.value();
      Value strideOp;
      if (ShapedType::isDynamic(strides[i])) {
        // The descriptor stride has the lowered index type; narrow it to i32
        // so both multiplication operands have the same type.
        Value dynStride = convertUnsignedToI32(
            rewriter, loc, memrefDescriptor.stride(rewriter, loc, i));
        strideOp =
            rewriter.create<LLVM::MulOp>(loc, dynStride, byteWidthConst);
      } else {
        strideOp =
            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
      }
      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
    }
    if (adaptor.getIndexOffset()) {
      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst);
    }
    args.push_back(voffset);

    // Scalar offset: the optional sgpr offset operand plus the memref's own
    // offset (static or read from the descriptor).
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    if (ShapedType::isDynamic(offset)) {
      Value dynOffset = convertUnsignedToI32(
          rewriter, loc, memrefDescriptor.offset(rewriter, loc));
      sgprOffset = rewriter.create<LLVM::AddOp>(loc, dynOffset, sgprOffset);
    } else if (offset > 0) {
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
    }
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      // Undo the integer/word bitcast so users see the type they asked for.
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};
/// Replaces an `LDSBarrierOp` with a side-effecting, operand-less inline
/// assembly op (AT&T dialect) that issues `s_waitcnt lgkmcnt(0)` followed by
/// `s_barrier`.
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  using ConvertOpToLLVMPattern<LDSBarrierOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The barrier takes no inputs and produces no outputs, hence the empty
    // constraint string and empty type/value ranges.
    const char *barrierAsm = "s_waitcnt lgkmcnt(0)\ns_barrier";
    const char *noConstraints = "";
    MLIRContext *context = rewriter.getContext();
    auto attDialect =
        LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_ATT);

    rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
        op,
        /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
        /*asm_string=*/barrierAsm, noConstraints, /*has_side_effects=*/true,
        /*is_align_stack=*/false, /*asm_dialect=*/attDialect,
        /*operand_attrs=*/ArrayAttr());
    return success();
  }
};
} // namespace
/// Adapts an MFMA source operand to the form the intrinsics expect.
///
/// - A vector of bfloat is bitcast to a vector of i16 with the same shape,
///   since the AMDGPU MFMA intrinsics pre-date the bfloat type.
/// - A vector of i8 is concatenated, in little-endian order (element 0 in the
///   lowest byte), into a single integer of width 8 * [vector length]. This
///   works around a wart in the AMDGPU intrinsics, where operations that
///   logically take vectors of bytes instead take integers. We do not want to
///   expose that implementation detail to MLIR, so we correct for it here.
/// - Anything else (non-vectors and other element types) is returned as is.
static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
                                Location loc, Value input) {
  auto vecTy = dyn_cast<VectorType>(input.getType());
  if (!vecTy)
    return input;

  Type elemTy = vecTy.getElementType();
  if (elemTy.isBF16())
    return rewriter.create<LLVM::BitcastOp>(
        loc, vecTy.clone(rewriter.getI16Type()), input);
  if (!elemTy.isInteger(8))
    return input;

  // Build the packed integer byte by byte: zero-extend each element, shift it
  // into position and OR it into the accumulator.
  const int64_t byteCount = vecTy.getNumElements();
  Type packedTy = rewriter.getIntegerType(byteCount * 8);
  Value packed = rewriter.create<LLVM::ConstantOp>(
      loc, packedTy, rewriter.getIntegerAttr(packedTy, 0));
  for (int64_t byteIdx = 0; byteIdx < byteCount; ++byteIdx) {
    Value lane = createI32Constant(rewriter, loc, byteIdx);
    Value byte = rewriter.create<LLVM::ExtractElementOp>(loc, input, lane);
    Value widened = rewriter.create<LLVM::ZExtOp>(loc, packedTy, byte);
    Value shiftAmount = rewriter.create<LLVM::ConstantOp>(
        loc, packedTy, rewriter.getIntegerAttr(packedTy, byteIdx * 8));
    Value positioned =
        rewriter.create<LLVM::ShlOp>(loc, widened, shiftAmount);
    packed = rewriter.create<LLVM::OrOp>(loc, packed, positioned);
  }
  return packed;
}
/// Appends the WMMA intrinsic operand(s) for one source matrix to `operands`.
///
/// `llvmInput` must be a vector value.
/// - bfloat vectors are bitcast to i16 vectors (the WMMA intrinsics
///   themselves lack bfloat support) and pushed.
/// - Other non-i8 vectors are pushed unchanged.
/// - i8 vectors are bitcast to a vector of i32 with the same total bit width
///   (e.g. 16xi8 becomes 4xi32), and two operands are pushed: first an i1
///   signedness flag (1 for signed, 0 for unsigned), then the packed vector.
///   If the element type carries explicit signedness, that takes precedence
///   over `isUnsigned`.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  // The element type is read unconditionally below, so use `cast` (which
  // asserts on a non-vector) instead of dereferencing an unchecked `dyn_cast`.
  auto vectorType = cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (!elemType.isInteger(8)) {
    operands.push_back(llvmInput);
    return;
  }

  int64_t numBytes = vectorType.getNumElements();
  Type i32 = rewriter.getI32Type();
  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);

  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmVectorType32bits, llvmInput);

  // If the element type is explicitly signed or unsigned, it overrides the
  // isUnsigned flag; signless i8 falls back to the flag.
  bool localIsUnsigned = isUnsigned;
  if (elemType.isUnsignedInteger(8)) {
    localIsUnsigned = true;
  } else if (elemType.isSignedInteger(8)) {
    localIsUnsigned = false;
  }
  Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
  operands.push_back(sign);
  operands.push_back(result);
}
/// Appends the WMMA intrinsic operand(s) for the accumulator to `operands`.
///
/// `output` must be a vector value. It is always pushed (bfloat vectors are
/// first bitcast to i16 vectors). Depending on the element type one extra i1
/// operand follows:
/// - f16 / bf16 / i16: the f16 -> f16 and bf16 -> bf16 intrinsics use the same
///   number of VGPRs as their wider counterparts, so `subwordOffset` selects
///   which 16-bit half of each VGPR holds the result.
///   NOTE(review): the previous comment claimed 1 selects the lower half;
///   confirm the polarity against the intrinsic definition.
/// - i32: the `clamp` flag.
/// Other element types get no extra operand.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  // The element type is read unconditionally below, so use `cast` (which
  // asserts on a non-vector) instead of dereferencing an unchecked `dyn_cast`.
  auto vectorType = cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    // Any non-zero offset becomes `true`.
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset != 0));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}
/// Looks up the name of the `rocdl` intrinsic that implements the MFMA
/// operation `mfma` on `chipset`.
///
/// Selection is keyed on the source/destination element types, the
/// (m, n, k, blocks) shape, the reduce-precision flag and the chipset's minor
/// version. Returns `std::nullopt` when no intrinsic matches, which includes
/// intrinsics that exist but are not available on the given chipset.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  // True when the operation has exactly the given tile shape.
  auto isShape = [&](uint32_t rows, uint32_t cols, uint32_t depth,
                     uint32_t blocks) {
    return m == rows && n == cols && k == depth && b == blocks;
  };
  // Minor-version thresholds used to gate newer intrinsics.
  const bool minorAtLeast0a = chipset.minorVersion >= 0x0a;
  const bool minorAtLeast40 = chipset.minorVersion >= 0x40;

  // Reduce possibly-vector operand types to their element types.
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

  // f32 x f32 -> f32 (optionally via the reduced-precision xf32 variants).
  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && minorAtLeast40) {
      if (isShape(32, 32, 4, 1))
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (isShape(16, 16, 8, 1))
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (isShape(32, 32, 1, 2))
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (isShape(16, 16, 1, 4))
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (isShape(4, 4, 1, 16))
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (isShape(32, 32, 2, 1))
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (isShape(16, 16, 4, 1))
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  // f16 x f16 -> f32.
  if (sourceElem.isF16() && destElem.isF32()) {
    if (isShape(32, 32, 4, 2))
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (isShape(16, 16, 4, 4))
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (isShape(4, 4, 4, 16))
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (isShape(32, 32, 8, 1))
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (isShape(16, 16, 16, 1))
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  // bf16 x bf16 -> f32, "_1k" variants gated on the minor version. A miss
  // here falls through to the original bf16 variants below.
  if (sourceElem.isBF16() && destElem.isF32() && minorAtLeast0a) {
    if (isShape(32, 32, 4, 2))
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (isShape(16, 16, 4, 4))
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (isShape(4, 4, 4, 16))
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (isShape(32, 32, 8, 1))
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (isShape(16, 16, 16, 1))
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
  }

  // bf16 x bf16 -> f32, original variants.
  if (sourceElem.isBF16() && destElem.isF32()) {
    if (isShape(32, 32, 2, 2))
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (isShape(16, 16, 2, 4))
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (isShape(4, 4, 2, 16))
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (isShape(32, 32, 4, 1))
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (isShape(16, 16, 8, 1))
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  // integer x integer -> i32.
  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (isShape(32, 32, 4, 2))
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (isShape(16, 16, 4, 4))
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (isShape(4, 4, 4, 16))
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (isShape(32, 32, 8, 1))
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (isShape(16, 16, 16, 1))
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (isShape(32, 32, 16, 1) && minorAtLeast40)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (isShape(16, 16, 32, 1) && minorAtLeast40)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  // f64 x f64 -> f64.
  if (sourceElem.isF64() && destElem.isF64() && minorAtLeast0a) {
    if (isShape(16, 16, 4, 1))
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (isShape(4, 4, 4, 4))
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  // 8-bit float sources with A = E5M2FNUZ ("bf8"); B may be either format.
  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && minorAtLeast40) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (isShape(16, 16, 32, 1)) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (isShape(32, 32, 16, 1)) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  // 8-bit float sources with A = E4M3FNUZ ("fp8"); B may be either format.
  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && minorAtLeast40) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (isShape(16, 16, 32, 1)) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (isShape(32, 32, 16, 1)) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists, or `std::nullopt` otherwise.
///
/// The choice depends only on the element types of source A and destination
/// C, both of which must be vectors. `chipset` is currently not consulted
/// here; the caller is responsible for checking the architecture.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  // The element types are read unconditionally, so use `cast` (which asserts
  // on a non-vector) instead of dereferencing an unchecked `dyn_cast`.
  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
  Type elemSourceType = sourceVectorType.getElementType();
  Type elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  return std::nullopt;
}
namespace {
/// Lowers `MFMAOp` to the matching `rocdl` MFMA intrinsic for `chipset`.
///
/// Fails with an op error when the chipset is not gfx908 or newer within
/// major version 9, when negation is requested on a chipset older than
/// gfx940, or when no intrinsic matches the operation's shape and types.
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  /// Target chipset used to select and gate intrinsics.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    // The intrinsics have no bfloat results; produce i16 and bitcast back.
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
      return op->emitOpError("MFMA only supported on gfx908+");

    // The negate flags share the intrinsic's last immediate with the
    // `blgp` attribute: bit 0 = negate A, bit 1 = negate B, bit 2 = negate C.
    uint32_t blgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      // Major version 9 was established above, so minor < 0x40 means
      // "older than gfx940".
      if (chipset.minorVersion < 0x40)
        return op.emitOpError("negation unsupported on older than gfx940");
      blgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }

    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    // The intrinsic is only known by name, so build it generically.
    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(intrinsicOutType);
    Value sourceA = mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA());
    Value sourceB = mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB());
    Value cbsz = createI32Constant(rewriter, loc, op.getCbsz());
    Value abid = createI32Constant(rewriter, loc, op.getAbid());
    Value blgp = createI32Constant(rewriter, loc, blgpField);
    loweredOp.addOperands(
        {sourceA, sourceB, adaptor.getDestC(), cbsz, abid, blgp});

    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};
/// Lowers `WMMAOp` to the matching `rocdl` WMMA intrinsic.
///
/// Emits an op error when the chipset's major version is not 11 or when no
/// intrinsic exists for the operand element types.
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  /// Target chipset used to gate the lowering.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type resultTy = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 11)
      return op->emitOpError("WMMA only supported on gfx11");

    std::optional<StringRef> intrinsicName = wmmaOpToIntrinsic(op, chipset);
    if (!intrinsicName.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    // Gather the intrinsic arguments: A, B (each possibly preceded by a
    // signedness flag), then the accumulator and its trailing flag.
    SmallVector<Value, 4> intrinsicArgs;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), intrinsicArgs);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), intrinsicArgs);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), intrinsicArgs);

    // The intrinsic is only known by name, so build it generically.
    OperationState state(loc, *intrinsicName);
    state.addTypes(resultTy);
    state.addOperands(intrinsicArgs);
    Operation *intrinsic = rewriter.create(state);
    rewriter.replaceOp(op, intrinsic->getResults());
    return success();
  }
};
namespace {
/// Conversion pattern for `ExtPackedFp8Op`; the rewrite is defined out of
/// line. The converter is taken by const reference, matching the other
/// patterns in this file.
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}

  /// Target chipset used to gate the lowering.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
/// Conversion pattern for `PackedTrunc2xFp8Op`; the rewrite is defined out of
/// line. The converter is taken by const reference, matching the other
/// patterns in this file.
struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}

  /// Target chipset used to gate the lowering.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
/// Conversion pattern for `PackedStochRoundFp8Op`; the rewrite is defined out
/// of line. The converter is taken by const reference, matching the other
/// patterns in this file.
struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}

  /// Target chipset used to gate the lowering.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace
/// Lowers `ExtPackedFp8Op` to `ROCDL::CvtF32Bf8Op` (E5M2FNUZ source) or
/// `ROCDL::CvtF32Fp8Op` (E4M3FNUZ source).
///
/// The source (a scalar or a vector of fewer than four elements) is first
/// widened to a four-element i8 vector whose unused lanes are undefined, then
/// bitcast to i32; the op's index attribute selects the element to convert.
///
/// Reports a match failure when the chipset is not gfx940 or newer within
/// major version 9, or when the source element type is neither of the two
/// supported 8-bit float formats.
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  // Reject unsupported element types before creating any IR; previously this
  // case returned success() without replacing the op.
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  const bool isBf8 = sourceElemType.isFloat8E5M2FNUZ();
  if (!isBf8 && !sourceElemType.isFloat8E4M3FNUZ())
    return rewriter.notifyMatchFailure(
        loc, "source element type must be f8E5M2FNUZ or f8E4M3FNUZ");

  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  // Extend to a v4i8
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      // Scalar source: place it in lane 0.
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      // Short vector: copy each element into the same lane.
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
  if (isBf8) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                    wordSel);
  } else {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                    wordSel);
  }
  return success();
}
/// Lowers `PackedTrunc2xFp8Op` to `ROCDL::CvtPkBf8F32Op` (E5M2FNUZ result) or
/// `ROCDL::CvtPkFp8F32Op` (E4M3FNUZ result), then bitcasts the i32 result to
/// the converted result type.
///
/// A missing second source and a missing `existing` value are both replaced
/// by undefined values; a present `existing` value is bitcast to i32.
///
/// Reports a match failure when the chipset is not gfx940 or newer within
/// major version 9, or when the result element type is neither of the two
/// supported 8-bit float formats.
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  // Reject unsupported element types before creating any IR; previously a
  // null value would have been passed to the final bitcast.
  const bool isBf8 = resultElemType.isFloat8E5M2FNUZ();
  if (!isBf8 && !resultElemType.isFloat8E4M3FNUZ())
    return rewriter.notifyMatchFailure(
        loc, "result element type must be f8E5M2FNUZ or f8E4M3FNUZ");

  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

  Value result;
  if (isBf8)
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);
  else
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
/// Lowers `PackedStochRoundFp8Op` to `ROCDL::CvtSrBf8F32Op` (E5M2FNUZ result)
/// or `ROCDL::CvtSrFp8F32Op` (E4M3FNUZ result), then bitcasts the i32 result
/// to the converted result type.
///
/// A missing `existing` value is replaced by an undefined i32; a present one
/// is bitcast to i32. The op's store index is passed as the byte selector.
///
/// Reports a match failure when the chipset is not gfx940 or newer within
/// major version 9, or when the result element type is neither of the two
/// supported 8-bit float formats.
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  // Reject unsupported element types before creating any IR; previously a
  // null value would have been passed to the final bitcast.
  const bool isBf8 = resultElemType.isFloat8E5M2FNUZ();
  if (!isBf8 && !resultElemType.isFloat8E4M3FNUZ())
    return rewriter.notifyMatchFailure(
        loc, "result element type must be f8E5M2FNUZ or f8E4M3FNUZ");

  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

  Value result;
  if (isBf8)
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);
  else
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
/// Pass that converts AMDGPU dialect operations to the ROCDL and LLVM
/// dialects for the chipset named by the `chipset` pass option.
struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *context = &getContext();

    // Refuse to run on an unparseable chipset name.
    FailureOr<Chipset> parsedChipset = Chipset::parse(chipset);
    if (failed(parsedChipset)) {
      emitError(UnknownLoc::get(context), "Invalid chipset name: " + chipset);
      signalPassFailure();
      return;
    }

    RewritePatternSet patterns(context);
    LLVMTypeConverter typeConverter(context);
    populateAMDGPUToROCDLConversionPatterns(typeConverter, patterns,
                                            *parsedChipset);

    // Every AMDGPU op must be rewritten; LLVM and ROCDL ops may remain.
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();

    LogicalResult status =
        applyPartialConversion(getOperation(), target, std::move(patterns));
    if (failed(status))
      signalPassFailure();
  }
};
} // namespace
/// Registers the AMDGPU-to-ROCDL lowering patterns in `patterns` and installs
/// bfloat type conversions on `converter`: scalar bf16 becomes i16, and
/// vectors of bf16 are converted as the equivalent vectors of i16.
///
/// The vector conversion callback captures `converter` by reference, so the
/// converter must outlive any use of the registered conversions.
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns,
                                                   Chipset chipset) {
  converter.addConversion([](BFloat16Type bf16Ty) -> Type {
    return IntegerType::get(bf16Ty.getContext(), 16);
  });
  converter.addConversion(
      [&converter](VectorType vecTy) -> std::optional<Type> {
        // Only bf16 vectors are handled here; defer everything else to the
        // other registered conversions.
        if (!vecTy.getElementType().isBF16())
          return std::nullopt;
        Type i16Ty = IntegerType::get(vecTy.getContext(), 16);
        return converter.convertType(vecTy.clone(i16Ty));
      });

  // The barrier lowering does not depend on the chipset.
  patterns.add<LDSBarrierOpLowering>(converter);

  // Raw buffer loads, stores and atomics.
  patterns.add<
      RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
      RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
      RawBufferOpLowering<RawBufferAtomicFaddOp,
                          ROCDL::RawPtrBufferAtomicFaddOp>,
      RawBufferOpLowering<RawBufferAtomicFmaxOp,
                          ROCDL::RawPtrBufferAtomicFmaxOp>,
      RawBufferOpLowering<RawBufferAtomicSmaxOp,
                          ROCDL::RawPtrBufferAtomicSmaxOp>,
      RawBufferOpLowering<RawBufferAtomicUminOp,
                          ROCDL::RawPtrBufferAtomicUminOp>,
      RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                          ROCDL::RawPtrBufferAtomicCmpSwap>>(converter,
                                                             chipset);

  // Matrix multiply-accumulate and packed 8-bit float conversions.
  patterns.add<MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
               PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(
      converter, chipset);
}
/// Creates a new, default-constructed instance of the AMDGPU-to-ROCDL
/// conversion pass.
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  std::unique_ptr<Pass> pass = std::make_unique<ConvertAMDGPUToROCDLPass>();
  return pass;
}