//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "../LLVMCommon/MemRefDescriptor.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include namespace mlir { #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; using namespace mlir::amdgpu; // Define commonly used chipsets versions for convenience. constexpr Chipset kGfx908 = Chipset(9, 0, 8); constexpr Chipset kGfx90a = Chipset(9, 0, 0xa); constexpr Chipset kGfx942 = Chipset(9, 4, 2); constexpr Chipset kGfx950 = Chipset(9, 5, 0); constexpr Chipset kGfx1250 = Chipset(12, 5, 0); /// Convert an unsigned number `val` to i32. static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val) { IntegerType i32 = rewriter.getI32Type(); // Force check that `val` is of int type. 
auto valTy = cast(val.getType()); if (i32 == valTy) return val; return valTy.getWidth() > 32 ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val)) : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val)); } static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value) { return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value); } /// Convert an unsigned number `val` to i64. static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val) { IntegerType i64 = rewriter.getI64Type(); // Force check that `val` is of int type. auto valTy = cast(val.getType()); if (i64 == valTy) return val; return valTy.getWidth() > 64 ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val)) : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val)); } static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value) { return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value); } /// Returns the linear index used to access an element in the memref. static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef strides) { IntegerType i32 = rewriter.getI32Type(); Value index; for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) { if (stride != 1) { // Skip if stride is 1. Value strideValue = ShapedType::isDynamic(stride) ? convertUnsignedToI32(rewriter, loc, memRefDescriptor.stride(rewriter, loc, i)) : LLVM::ConstantOp::create(rewriter, loc, i32, stride); increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue); } index = index ? LLVM::AddOp::create(rewriter, loc, index, increment) : increment; } return index ? index : createI32Constant(rewriter, loc, 0); } /// Compute the contents of the `num_records` field for a given memref /// descriptor - that is, the number of bytes that's one element past the /// greatest possible valid index into the memref. 
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck) { if (chipset >= kGfx1250 && !boundsCheck) { constexpr int64_t first45bits = (1ll << 45) - 1; return createI64Constant(rewriter, loc, first45bits); } if (memrefType.hasStaticShape() && !llvm::any_of(strides, ShapedType::isDynamic)) { int64_t size = memrefType.getRank() == 0 ? 1 : 0; ArrayRef shape = memrefType.getShape(); for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) size = std::max(shape[i] * strides[i], size); size = size * elementByteWidth; return createI64Constant(rewriter, loc, size); } Value maxIndex; for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { Value size = memrefDescriptor.size(rewriter, loc, i); Value stride = memrefDescriptor.stride(rewriter, loc, i); Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride); maxIndex = maxIndex ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim) : maxThisDim; } Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex); Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth); return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst); } static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride = nullptr, unsigned addressSpace = 8) { // The stride value is generally 0. However, on MI-300 and onward, you can // enable a cache swizzling mode by setting bit 14 of the stride field // and setting that stride to a cache stride. 
Type i16 = rewriter.getI16Type(); Value stride; if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) { Value cacheStrideZext = LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride); Value swizzleBit = LLVM::ConstantOp::create( rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14)); stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit, /*isDisjoint=*/true); } else { stride = LLVM::ConstantOp::create(rewriter, loc, i16, rewriter.getI16IntegerAttr(0)); } uint32_t flags = 0; if (chipset >= kGfx1250) { // Flag word: // bit 0: swizzle // bit 1: 0 means (total_offset + payload > numRecords) // 1 means ((total_offset + payload >) numRecords) || ((offset + // payload) > stride) only applied when swizzle_enable = 0. keep at // zero. // whether oob is done depends on numRecords. // bits 2-3: Type (must be 0) } else { // Get the number of elements. // Flag word: // bits 0-11: dst sel, ignored by these intrinsics // bits 12-14: data format (ignored, must be nonzero, 7=float) // bits 15-18: data format (ignored, must be nonzero, 4=32bit) // bit 19: In nested heap (0 here) // bit 20: Behavior on unmap (0 means "return 0 / ignore") // bits 21-22: Index stride for swizzles (N/A) // bit 23: Add thread ID (0) // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) // bits 25-26: Reserved (0) // bit 27: Buffer is non-volatile (CDNA only) // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 = // none, 3 = either swizzles or testing against offset field) RDNA only // bits 30-31: Type (must be 0) flags |= (7 << 12) | (4 << 15); if (chipset.majorVersion >= 10) { flags |= (1 << 24); uint32_t oob = boundsCheck ? 
3 : 2; flags |= (oob << 28); } } Value flagsConst = createI32Constant(rewriter, loc, flags); Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); Value resource = rewriter.createOrFold( loc, rsrcType, basePointer, stride, numRecords, flagsConst); return resource; } namespace { struct FatRawBufferCastLowering : public ConvertOpToLLVMPattern { FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value memRef = adaptor.getSource(); Value unconvertedMemref = op.getSource(); MemRefType memrefType = cast(unconvertedMemref.getType()); MemRefDescriptor descriptor(memRef); DataLayout dataLayout = DataLayout::closest(op); int64_t elementByteWidth = dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8; int64_t unusedOffset = 0; SmallVector strideVals; if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset))) return op.emitOpError("Can't lower non-stride-offset memrefs"); Value numRecords = adaptor.getValidBytes(); if (!numRecords) numRecords = getNumRecords(rewriter, loc, memrefType, descriptor, strideVals, elementByteWidth, chipset, adaptor.getBoundsCheck()); Value basePointer = adaptor.getResetOffset() ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), memrefType) : descriptor.alignedPtr(rewriter, loc); Value offset = adaptor.getResetOffset() ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(0)) : descriptor.offset(rewriter, loc); bool hasSizes = memrefType.getRank() > 0; // No need to unpack() and pack() all the individual sizes and strides, // so we'll just extract the arrays. Value sizes = hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor, kSizePosInMemRefDescriptor) : Value{}; Value strides = hasSizes ? 
LLVM::ExtractValueOp::create(rewriter, loc, descriptor, kStridePosInMemRefDescriptor) : Value{}; Value fatPtr = makeBufferRsrc( rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(), chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7); Value result = MemRefDescriptor::poison( rewriter, loc, getTypeConverter()->convertType(op.getResult().getType())); SmallVector pos{kAllocatedPtrPosInMemRefDescriptor}; result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos); result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor); result = LLVM::InsertValueOp::create(rewriter, loc, result, offset, kOffsetPosInMemRefDescriptor); if (hasSizes) { result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes, kSizePosInMemRefDescriptor); result = LLVM::InsertValueOp::create(rewriter, loc, result, strides, kStridePosInMemRefDescriptor); } rewriter.replaceOp(op, result); return success(); } }; /// Define lowering patterns for raw buffer ops template struct RawBufferOpLowering : public ConvertOpToLLVMPattern { RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; static constexpr uint32_t maxVectorOpWidth = 128; LogicalResult matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = gpuOp.getLoc(); Value memref = adaptor.getMemref(); Value unconvertedMemref = gpuOp.getMemref(); MemRefType memrefType = cast(unconvertedMemref.getType()); if (chipset.majorVersion < 9) return gpuOp.emitOpError("raw buffer ops require GCN or higher"); Value storeData = adaptor.getODSOperands(0)[0]; if (storeData == memref) // no write component to this op storeData = Value(); Type wantedDataType; if (storeData) wantedDataType = storeData.getType(); else wantedDataType = gpuOp.getODSResults(0)[0].getType(); Value atomicCmpData = Value(); // Operand index 1 of a 
load is the indices, trying to read them can crash. if (storeData) { Value maybeCmpData = adaptor.getODSOperands(1)[0]; if (maybeCmpData != memref) atomicCmpData = maybeCmpData; } Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType); Type i32 = rewriter.getI32Type(); // Get the type size in bytes. DataLayout dataLayout = DataLayout::closest(gpuOp); int64_t elementByteWidth = dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8; Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); // If we want to load a vector with total size <= 32 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32 // and the total load size is >= 32, use a vector load of N / (bitsize(T) / // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands, // so bitcast any floats to integers. Type llvmBufferValType = llvmWantedDataType; if (atomicCmpData) { if (auto floatType = dyn_cast(wantedDataType)) llvmBufferValType = this->getTypeConverter()->convertType( rewriter.getIntegerType(floatType.getWidth())); } if (auto dataVector = dyn_cast(wantedDataType)) { uint32_t vecLen = dataVector.getNumElements(); uint32_t elemBits = dataLayout.getTypeSizeInBits(dataVector.getElementType()); uint32_t totalBits = elemBits * vecLen; bool usePackedFp16 = isa_and_present(*gpuOp) && vecLen == 2; if (totalBits > maxVectorOpWidth) return gpuOp.emitOpError( "Total width of loads or stores must be no more than " + Twine(maxVectorOpWidth) + " bits, but we call for " + Twine(totalBits) + " bits. This should've been caught in validation"); if (!usePackedFp16 && elemBits < 32) { if (totalBits > 32) { if (totalBits % 32 != 0) return gpuOp.emitOpError("Load or store of more than 32-bits that " "doesn't fit into words. 
Can't happen\n"); llvmBufferValType = this->typeConverter->convertType( VectorType::get(totalBits / 32, i32)); } else { llvmBufferValType = this->typeConverter->convertType( rewriter.getIntegerType(totalBits)); } } } if (auto vecType = dyn_cast(llvmBufferValType)) { // Buffer intrinsics doesn't support 1-element vectors, cast them to // scalars. if (vecType.getNumElements() == 1) llvmBufferValType = vecType.getElementType(); } SmallVector args; if (storeData) { if (llvmBufferValType != llvmWantedDataType) { Value castForStore = LLVM::BitcastOp::create( rewriter, loc, llvmBufferValType, storeData); args.push_back(castForStore); } else { args.push_back(storeData); } } if (atomicCmpData) { if (llvmBufferValType != llvmWantedDataType) { Value castForCmp = LLVM::BitcastOp::create( rewriter, loc, llvmBufferValType, atomicCmpData); args.push_back(castForCmp); } else { args.push_back(atomicCmpData); } } // Construct buffer descriptor from memref, attributes int64_t offset = 0; SmallVector strides; if (failed(memrefType.getStridesAndOffset(strides, offset))) return gpuOp.emitOpError("Can't lower non-stride-offset memrefs"); MemRefDescriptor memrefDescriptor(memref); Value ptr = memrefDescriptor.bufferPtr( rewriter, loc, *this->getTypeConverter(), memrefType); Value numRecords = getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth, chipset, adaptor.getBoundsCheck()); Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords, adaptor.getBoundsCheck(), chipset); args.push_back(resource); // Indexing (voffset) Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor, adaptor.getIndices(), strides); if (std::optional indexOffset = adaptor.getIndexOffset(); indexOffset && *indexOffset > 0) { Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset); voffset = voffset ? 
LLVM::AddOp::create(rewriter, loc, voffset, extraOffsetConst) : extraOffsetConst; } voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst); args.push_back(voffset); // SGPR offset. Value sgprOffset = adaptor.getSgprOffset(); if (!sgprOffset) sgprOffset = createI32Constant(rewriter, loc, 0); sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst); args.push_back(sgprOffset); // bit 0: GLC = 0 (atomics drop value, less coherency) // bits 1-2: SLC, DLC = 0 (similarly) // bit 3: swizzled (0 for raw) args.push_back(createI32Constant(rewriter, loc, 0)); llvm::SmallVector resultTypes(gpuOp->getNumResults(), llvmBufferValType); Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args, ArrayRef()); if (lowered->getNumResults() == 1) { Value replacement = lowered->getResult(0); if (llvmBufferValType != llvmWantedDataType) { replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType, replacement); } rewriter.replaceOp(gpuOp, replacement); } else { rewriter.eraseOp(gpuOp); } return success(); } }; // TODO: AMDGPU backend already have all this bitpacking logic, we should move // it to some common place. 
/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows: /// Vmcnt = Waitcnt[3:0] (pre-gfx9) /// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10) /// Vmcnt = Waitcnt[15:10] (gfx11) /// Expcnt = Waitcnt[6:4] (pre-gfx11) /// Expcnt = Waitcnt[2:0] (gfx11) /// Lgkmcnt = Waitcnt[11:8] (pre-gfx10) /// Lgkmcnt = Waitcnt[13:8] (gfx10) /// Lgkmcnt = Waitcnt[9:4] (gfx11) static FailureOr encodeWaitcnt(Chipset chipset, unsigned vmcnt, unsigned expcnt, unsigned lgkmcnt) { if (chipset.majorVersion < 9) { vmcnt = std::min(15u, vmcnt); expcnt = std::min(7u, expcnt); lgkmcnt = std::min(15u, lgkmcnt); return vmcnt | (expcnt << 4) | (lgkmcnt << 8); } if (chipset.majorVersion == 9) { vmcnt = std::min(63u, vmcnt); expcnt = std::min(7u, expcnt); lgkmcnt = std::min(15u, lgkmcnt); unsigned lowBits = vmcnt & 0xF; unsigned highBits = (vmcnt >> 4) << 14; unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8); return lowBits | highBits | otherCnts; } if (chipset.majorVersion == 10) { vmcnt = std::min(63u, vmcnt); expcnt = std::min(7u, expcnt); lgkmcnt = std::min(63u, lgkmcnt); unsigned lowBits = vmcnt & 0xF; unsigned highBits = (vmcnt >> 4) << 14; unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8); return lowBits | highBits | otherCnts; } if (chipset.majorVersion == 11) { vmcnt = std::min(63u, vmcnt); expcnt = std::min(7u, expcnt); lgkmcnt = std::min(63u, lgkmcnt); return (vmcnt << 10) | expcnt | (lgkmcnt << 4); } return failure(); } struct MemoryCounterWaitOpLowering : public ConvertOpToLLVMPattern { MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (chipset.majorVersion >= 12) { Location loc = op.getLoc(); if (std::optional ds = adaptor.getDs()) ROCDL::WaitDscntOp::create(rewriter, loc, *ds); if (std::optional load = adaptor.getLoad()) ROCDL::WaitLoadcntOp::create(rewriter, 
loc, *load); if (std::optional store = adaptor.getStore()) ROCDL::WaitStorecntOp::create(rewriter, loc, *store); if (std::optional exp = adaptor.getExp()) ROCDL::WaitExpcntOp::create(rewriter, loc, *exp); if (std::optional tensor = adaptor.getTensor()) ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor); rewriter.eraseOp(op); return success(); } if (adaptor.getTensor()) return op.emitOpError("unsupported chipset"); auto getVal = [](Attribute attr) -> unsigned { if (attr) return cast(attr).getInt(); // This value will be clamped to the maximum value for the chipset. return 1024; }; unsigned ds = getVal(adaptor.getDsAttr()); unsigned exp = getVal(adaptor.getExpAttr()); unsigned vmcnt = 1024; Attribute load = adaptor.getLoadAttr(); Attribute store = adaptor.getStoreAttr(); if (load && store) { vmcnt = getVal(load) + getVal(store); } else if (load) { vmcnt = getVal(load); } else if (store) { vmcnt = getVal(store); } FailureOr waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds); if (failed(waitcnt)) return op.emitOpError("unsupported chipset"); rewriter.replaceOpWithNewOp(op, *waitcnt); return success(); } }; struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern { LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); // This ensures that waits on global memory aren't introduced on // chips that don't have the BackOffBarrier feature enabled in LLVM. bool requiresInlineAsm = chipset < kGfx90a; Attribute mmra = rewriter.getAttr("amdgpu-synchronize-as", "local"); // Note: while there *is* a workgroup-one-as scope, this, when combined with // the MMRA, will lead to the fence having no effect. 
This is because the // codepaths for an atomic load or store will observe that a // one-address-space atomic to LDS requires no synchronization because // operations on LDS are totally ordered with respect to each other, and so // will not emit the correct waitcnt operations that these fences are // intended to produce. Therefore, we use a broader type of fence and rely // on the MMRA to relax it to the semantics we want. StringRef scope = "workgroup"; auto relFence = LLVM::FenceOp::create(rewriter, loc, LLVM::AtomicOrdering::release, scope); relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra); if (requiresInlineAsm) { auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), LLVM::AsmDialect::AD_ATT); const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier"; const char *constraints = ""; LLVM::InlineAsmOp::create( rewriter, loc, /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(), /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true, /*is_align_stack=*/false, LLVM::TailCallKind::None, /*asm_dialect=*/asmDialectAttr, /*operand_attrs=*/ArrayAttr()); } else if (chipset.majorVersion < 12) { ROCDL::SBarrierOp::create(rewriter, loc); } else { ROCDL::BarrierSignalOp::create(rewriter, loc, -1); ROCDL::BarrierWaitOp::create(rewriter, loc, -1); } auto acqFence = LLVM::FenceOp::create(rewriter, loc, LLVM::AtomicOrdering::acquire, scope); acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra); rewriter.replaceOp(op, acqFence); return success(); } }; struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern { SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, (uint32_t)op.getOpts()); return success(); } }; } // namespace /// Pack small 
float vector operands (fp4/fp6/fp8/bf16) into the format /// expected by scaled matrix multiply intrinsics (MFMA/WMMA). /// /// Specifically: /// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic /// allows bf16. Newer MFMAs support bf16 types on operand, check /// IntrinsicsAMDGPU.td file for reference. /// 2. If instead we have a more than 64-bit quantity, use a /// instead, which is what the f8f6f4 intrinsics use. /// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit /// integer. /// /// Note that the type of `input` has already been LLVM type converted: /// therefore 8-bit and smaller floats are represented as their corresponding /// `iN` integers. static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16 = true) { Type inputType = input.getType(); if (auto vectorType = dyn_cast(inputType)) { if (vectorType.getElementType().isBF16() && !allowBf16) return LLVM::BitcastOp::create( rewriter, loc, vectorType.clone(rewriter.getI16Type()), input); if (vectorType.getElementType().isInteger(8) && vectorType.getNumElements() <= 8) return LLVM::BitcastOp::create( rewriter, loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input); if (isa(vectorType.getElementType()) && vectorType.getElementTypeBitWidth() <= 8) { int64_t numWords = llvm::divideCeil( vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32); return LLVM::BitcastOp::create( rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input); } } return input; } /// Converts sparse MFMA (smfmac) operands to the expected ROCDL types. static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16 = true) { Type inputType = input.getType(); auto vectorType = cast(inputType); // bf16 -> i16 when not allowed (pre-gfx950). 
if (vectorType.getElementType().isBF16() && !allowBf16) return LLVM::BitcastOp::create( rewriter, loc, vectorType.clone(rewriter.getI16Type()), input); // i8/fp8 vectors -> vector. if (isa(vectorType.getElementType()) && vectorType.getElementTypeBitWidth() <= 8) { int64_t numWords = llvm::divideCeil( vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32); return LLVM::BitcastOp::create( rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input); } return input; } /// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR /// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention. /// /// Specifically: /// 1. If `input` is a i8 value, zero extend it to i32 /// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32 /// /// Note that the type of `input` has already been LLVM type converted: /// therefore 8-bit and smaller floats are represented as their corresponding /// `iN` integers. static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input) { return TypeSwitch(input.getType()) .Case([&](IntegerType) { // Handle scalar i8: zero extend to i32. return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(), input); }) .Case([&](VectorType vectorType) { // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64. int64_t numElements = vectorType.getNumElements(); assert((numElements == 4 || numElements == 8) && "scale operand must be a vector of length 4 or 8"); IntegerType outputType = (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type(); return LLVM::BitcastOp::create(rewriter, loc, outputType, input); }) .DefaultUnreachable("unexpected input type for scale operand"); } /// Maps f8 scale element types to WMMA scale format codes. 
static std::optional getWmmaScaleFormat(Type elemType) { return TypeSwitch>(elemType) .Case([](Float8E8M0FNUType) { return 0; }) .Case([](Float8E4M3FNType) { return 2; }) .Default(std::nullopt); } /// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions /// and scale block size (16 or 32). static std::optional getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16) { if (m == 16 && n == 16 && k == 128) return isScale16 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName() : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName(); if (m == 32 && n == 16 && k == 128) return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName() : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName(); return std::nullopt; } /// Push an input operand. If it is a float type, nothing to do. If it is /// an integer type, then we need to also push its signdness (1 for signed, 0 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32 /// vector (or the 8xi8 vector into a 2xi32 one for gfx12+). /// We also need to convert bfloat inputs to i16 to account for the bfloat /// intrinsics having been defined before the AMD backend supported bfloat. We /// similarly need to pack 8-bit float types into integers as if they were i8 /// (which they are for the backend's purposes). static void wmmaPushInputOperand( ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl &operands, SmallVectorImpl &attrs, StringRef attrName) { Type inputType = llvmInput.getType(); auto vectorType = dyn_cast(inputType); if (!vectorType) { operands.push_back(llvmInput); return; } Type elemType = vectorType.getElementType(); if (elemType.getIntOrFloatBitWidth() > 8) { operands.push_back(llvmInput); return; } // We need to check the type of the input before conversion to properly test // for int8. 
This is because, in LLVM, fp8 type is converted to int8, so the // fp8/int8 information is lost during the conversion process. auto mlirInputType = cast(mlirInput.getType()); bool isInputInteger = mlirInputType.getElementType().isInteger(); if (isInputInteger) { // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag bool localIsUnsigned = isUnsigned; if (elemType.isUnsignedInteger()) { localIsUnsigned = true; } else if (elemType.isSignedInteger()) { localIsUnsigned = false; } attrs.push_back( NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned))); } int64_t numBits = vectorType.getNumElements() * elemType.getIntOrFloatBitWidth(); Type i32 = rewriter.getI32Type(); Type intrinsicInType = numBits <= 32 ? (Type)rewriter.getIntegerType(numBits) : (Type)VectorType::get(numBits / 32, i32); auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType); Value castInput = rewriter.createOrFold( loc, llvmIntrinsicInType, llvmInput); // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments. // Add in the zeros here. if (numBits < 32) castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput); operands.push_back(castInput); } /// Push the output operand. For many cases this is only pushing the output in /// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics, /// since the same numbers of VGPRs is used, we need to decide if to store the /// result in the upper 16 bits of the VGPRs or in the lower part. To store the /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will /// be stored it in the upper part. The subwordOffset must not be set for gfx12, /// as the instructions have been changed to return fewer registers instead. 
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl &operands, SmallVectorImpl &attrs) { Type inputType = output.getType(); auto vectorType = dyn_cast(inputType); Type elemType = vectorType.getElementType(); operands.push_back(output); if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { attrs.push_back( NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset))); } else if (elemType.isInteger(32)) { attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp))); } } /// Return true if `type` is the E5M2 variant of an 8-bit float that is /// supported by the `_bf8` instructions on the given `chipset`. static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) { return (chipset == kGfx942 && isa(type)) || (hasOcpFp8(chipset) && isa(type)); } /// Return true if `type` is the E4M3FN variant of an 8-bit float that is /// supported by the `_fp8` instructions on the given `chipset`. static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) { return (chipset == kGfx942 && isa(type)) || (hasOcpFp8(chipset) && isa(type)); } /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma` /// if one exists. This includes checking to ensure the intrinsic is supported /// on the architecture you are compiling for. 
static std::optional mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset) { uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(), b = mfma.getBlocks(); Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType()); Type destElem = getElementTypeOrSelf(mfma.getDestC().getType()); if (sourceElem.isF32() && destElem.isF32()) { if (mfma.getReducePrecision() && chipset >= kGfx942) { if (m == 32 && n == 32 && k == 4 && b == 1) return ROCDL::mfma_f32_32x32x4_xf32::getOperationName(); if (m == 16 && n == 16 && k == 8 && b == 1) return ROCDL::mfma_f32_16x16x8_xf32::getOperationName(); } if (m == 32 && n == 32 && k == 1 && b == 2) return ROCDL::mfma_f32_32x32x1f32::getOperationName(); if (m == 16 && n == 16 && k == 1 && b == 4) return ROCDL::mfma_f32_16x16x1f32::getOperationName(); if (m == 4 && n == 4 && k == 1 && b == 16) return ROCDL::mfma_f32_4x4x1f32::getOperationName(); if (m == 32 && n == 32 && k == 2 && b == 1) return ROCDL::mfma_f32_32x32x2f32::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 1) return ROCDL::mfma_f32_16x16x4f32::getOperationName(); } if (sourceElem.isF16() && destElem.isF32()) { if (chipset >= kGfx950) { if (m == 32 && n == 32 && k == 16 && b == 1) return ROCDL::mfma_f32_32x32x16_f16::getOperationName(); if (m == 16 && n == 16 && k == 32 && b == 1) return ROCDL::mfma_f32_16x16x32_f16::getOperationName(); } if (m == 32 && n == 32 && k == 4 && b == 2) return ROCDL::mfma_f32_32x32x4f16::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 4) return ROCDL::mfma_f32_16x16x4f16::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 16) return ROCDL::mfma_f32_4x4x4f16::getOperationName(); if (m == 32 && n == 32 && k == 8 && b == 1) return ROCDL::mfma_f32_32x32x8f16::getOperationName(); if (m == 16 && n == 16 && k == 16 && b == 1) return ROCDL::mfma_f32_16x16x16f16::getOperationName(); } if (sourceElem.isBF16() && destElem.isF32()) { if (chipset >= kGfx950) { if (m == 32 && n == 32 && k == 16 && b == 1) return 
ROCDL::mfma_f32_32x32x16_bf16::getOperationName(); if (m == 16 && n == 16 && k == 32 && b == 1) return ROCDL::mfma_f32_16x16x32_bf16::getOperationName(); } if (chipset >= kGfx90a) { if (m == 32 && n == 32 && k == 4 && b == 2) return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 4) return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 16) return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName(); if (m == 32 && n == 32 && k == 8 && b == 1) return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName(); if (m == 16 && n == 16 && k == 16 && b == 1) return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName(); } if (m == 32 && n == 32 && k == 2 && b == 2) return ROCDL::mfma_f32_32x32x2bf16::getOperationName(); if (m == 16 && n == 16 && k == 2 && b == 4) return ROCDL::mfma_f32_16x16x2bf16::getOperationName(); if (m == 4 && n == 4 && k == 2 && b == 16) return ROCDL::mfma_f32_4x4x2bf16::getOperationName(); if (m == 32 && n == 32 && k == 4 && b == 1) return ROCDL::mfma_f32_32x32x4bf16::getOperationName(); if (m == 16 && n == 16 && k == 8 && b == 1) return ROCDL::mfma_f32_16x16x8bf16::getOperationName(); } if (sourceElem.isInteger(8) && destElem.isInteger(32)) { if (chipset >= kGfx950) { if (m == 32 && n == 32 && k == 32 && b == 1) return ROCDL::mfma_i32_32x32x32_i8::getOperationName(); if (m == 16 && n == 16 && k == 64 && b == 1) return ROCDL::mfma_i32_16x16x64_i8::getOperationName(); } if (m == 32 && n == 32 && k == 4 && b == 2) return ROCDL::mfma_i32_32x32x4i8::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 4) return ROCDL::mfma_i32_16x16x4i8::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 16) return ROCDL::mfma_i32_4x4x4i8::getOperationName(); if (m == 32 && n == 32 && k == 8 && b == 1) return ROCDL::mfma_i32_32x32x8i8::getOperationName(); if (m == 16 && n == 16 && k == 16 && b == 1) return ROCDL::mfma_i32_16x16x16i8::getOperationName(); if (m == 32 && n == 32 && 
k == 16 && b == 1 && chipset >= kGfx942) return ROCDL::mfma_i32_32x32x16_i8::getOperationName(); if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942) return ROCDL::mfma_i32_16x16x32_i8::getOperationName(); } if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) { if (m == 16 && n == 16 && k == 4 && b == 1) return ROCDL::mfma_f64_16x16x4f64::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 4) return ROCDL::mfma_f64_4x4x4f64::getOperationName(); } if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) { // Known to be correct because there are no scalar f8 instructions and // because a length mismatch will have been caught by the verifier. Type sourceBElem = cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(); if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName(); } if (m == 32 && n == 32 && k == 16 && b == 1) { if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName(); if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName(); } } if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) { Type sourceBElem = cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName(); if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName(); } if (m == 32 && n == 32 && k == 16 && b == 1) { if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName(); if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return 
ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName(); } } return std::nullopt; } static std::optional smallFloatTypeToFormatCode(Type mlirElemType) { return llvm::TypeSwitch>(mlirElemType) .Case([](Float8E4M3FNType) { return 0u; }) .Case([](Float8E5M2Type) { return 1u; }) .Case([](Float6E2M3FNType) { return 2u; }) .Case([](Float6E3M2FNType) { return 3u; }) .Case([](Float4E2M1FNType) { return 4u; }) .Default(std::nullopt); } /// If there is a scaled MFMA instruction for the input element types `aType` /// and `bType`, output type `destType`, problem size M, N, K, and B (number of /// blocks) on the given `chipset`, return a tuple consisting of the /// OperationName of the intrinsic and the type codes that need to be passed to /// that intrinsic. Note that this is also used to implement some un-scaled /// MFMAs, since the compiler represents the ordinary instruction as a "scaled" /// MFMA with a scale of 0. static std::optional> mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset) { aType = getElementTypeOrSelf(aType); bType = getElementTypeOrSelf(bType); destType = getElementTypeOrSelf(destType); if (chipset < kGfx950) return std::nullopt; if (!isa(destType)) return std::nullopt; std::optional aTypeCode = smallFloatTypeToFormatCode(aType); std::optional bTypeCode = smallFloatTypeToFormatCode(bType); if (!aTypeCode || !bTypeCode) return std::nullopt; if (m == 32 && n == 32 && k == 64 && b == 1) return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), *aTypeCode, *bTypeCode}; if (m == 16 && n == 16 && k == 128 && b == 1) return std::tuple{ ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode, *bTypeCode}; return std::nullopt; } static std::optional> mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) { return mfmaOpToScaledIntrinsic( mfma.getSourceA().getType(), mfma.getSourceB().getType(), mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(), 
mfma.getBlocks(), chipset); } static std::optional> mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) { return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(), smfma.getSourceB().getType(), smfma.getDestC().getType(), smfma.getM(), smfma.getN(), smfma.getK(), 1u, chipset); } /// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` /// for RDNA3/4 architectures. static std::optional wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3) { using fp8 = Float8E4M3FNType; using bf8 = Float8E5M2Type; // Handle k == 16 for RDNA3/4. if (k == 16) { // Common patterns for RDNA3 and RDNA4. if (elemSourceType.isF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); if (elemSourceType.isBF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); if (elemSourceType.isF16() && elemDestType.isF16()) return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); if (elemSourceType.isBF16() && elemDestType.isBF16()) return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); // RDNA3 specific patterns. if (isRDNA3) { if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); return std::nullopt; } // RDNA4 specific patterns (fp8/bf8). 
if (isa(elemSourceType) && isa(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName(); if (isa(elemSourceType) && isa(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName(); if (isa(elemSourceType) && isa(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName(); if (isa(elemSourceType) && isa(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName(); if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); return std::nullopt; } // Handle k == 32 for RDNA4. if (k == 32 && !isRDNA3) { if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); } return std::nullopt; } /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` /// for the gfx1250 architecture. static std::optional wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k) { using fp8 = Float8E4M3FNType; using bf8 = Float8E5M2Type; if (k == 4) { if (elemSourceType.isF32() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x4_f32::getOperationName(); return std::nullopt; } if (k == 32) { if (elemSourceType.isF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x32_f16::getOperationName(); if (elemSourceType.isBF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x32_bf16::getOperationName(); if (elemSourceType.isF16() && elemDestType.isF16()) return ROCDL::wmma_f16_16x16x32_f16::getOperationName(); if (elemSourceType.isBF16() && elemDestType.isBF16()) return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName(); return std::nullopt; } if (k == 64) { if (isa(elemSourceType) && isa(elemBSourceType)) { if (elemDestType.isF32()) return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName(); if (elemDestType.isF16()) return 
ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName(); } if (isa(elemSourceType) && isa(elemBSourceType)) { if (elemDestType.isF32()) return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName(); if (elemDestType.isF16()) return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName(); } if (isa(elemSourceType) && isa(elemBSourceType)) { if (elemDestType.isF32()) return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName(); if (elemDestType.isF16()) return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName(); } if (isa(elemSourceType) && isa(elemBSourceType)) { if (elemDestType.isF32()) return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName(); if (elemDestType.isF16()) return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName(); } if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x64_iu8::getOperationName(); return std::nullopt; } if (k == 128) { if (isa(elemSourceType) && isa(elemBSourceType)) { if (elemDestType.isF32()) return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName(); if (elemDestType.isF16()) return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName(); } if (isa(elemSourceType) && isa(elemBSourceType)) { if (elemDestType.isF32()) return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName(); if (elemDestType.isF16()) return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName(); } if (isa(elemSourceType) && isa(elemBSourceType)) { if (elemDestType.isF32()) return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName(); if (elemDestType.isF16()) return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName(); } if (isa(elemSourceType) && isa(elemBSourceType)) { if (elemDestType.isF32()) return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName(); if (elemDestType.isF16()) return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName(); } return std::nullopt; } return std::nullopt; } /// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac) /// operation if one exists. 
This includes checking to ensure the intrinsic is /// supported on the architecture you are compiling for. static std::optional smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset) { bool isGfx950 = chipset >= kGfx950; auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); }; auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); }; uint32_t m = op.getM(), n = op.getN(), k = op.getK(); Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType()); Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType()); Type destElem = getElementTypeOrSelf(op.getDestC().getType()); if (m == 16 && n == 16 && k == 32) { if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32()) return ROCDL::smfmac_f32_16x16x32_f16::getOperationName(); if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32()) return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName(); } if (m == 16 && n == 16 && k == 64) { if (isGfx950) { if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32()) return ROCDL::smfmac_f32_16x16x64_f16::getOperationName(); if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32()) return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName(); } if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) && destElem.isInteger(32)) return ROCDL::smfmac_i32_16x16x64_i8::getOperationName(); if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName(); if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName(); if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName(); if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName(); } if (m == 16 && n == 16 && k == 128 && isGfx950) { if (sourceAElem.isInteger(8) && 
sourceBElem.isInteger(8) && destElem.isInteger(32)) return ROCDL::smfmac_i32_16x16x128_i8::getOperationName(); if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName(); if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName(); if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName(); if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName(); } if (m == 32 && n == 32 && k == 16) { if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32()) return ROCDL::smfmac_f32_32x32x16_f16::getOperationName(); if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32()) return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName(); } if (m == 32 && n == 32 && k == 32) { if (isGfx950) { if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32()) return ROCDL::smfmac_f32_32x32x32_f16::getOperationName(); if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32()) return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName(); } if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) && destElem.isInteger(32)) return ROCDL::smfmac_i32_32x32x32_i8::getOperationName(); if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName(); if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName(); if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName(); if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName(); } if (m == 32 && n == 32 && k == 64 && isGfx950) { if (sourceAElem.isInteger(8) && 
sourceBElem.isInteger(8) && destElem.isInteger(32)) return ROCDL::smfmac_i32_32x32x64_i8::getOperationName(); if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName(); if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName(); if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName(); if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName(); } return std::nullopt; } /// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` /// if one exists. This includes checking to ensure the intrinsic is supported /// on the architecture you are compiling for. static std::optional wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset) { auto sourceVectorType = cast(wmma.getSourceA().getType()); auto sourceBVectorType = cast(wmma.getSourceB().getType()); auto destVectorType = cast(wmma.getDestC().getType()); Type elemSourceType = sourceVectorType.getElementType(); Type elemBSourceType = sourceBVectorType.getElementType(); Type elemDestType = destVectorType.getElementType(); const uint32_t k = wmma.getK(); const bool isRDNA3 = chipset.majorVersion == 11; const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0; // Handle RDNA3 and RDNA4. if (isRDNA3 || isRDNA4) return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType, k, isRDNA3); // Handle gfx1250. 
if (chipset == kGfx1250) return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, elemDestType, k); return std::nullopt; } namespace { struct MFMAOpLowering : public ConvertOpToLLVMPattern { MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type outType = typeConverter->convertType(op.getDestD().getType()); Type intrinsicOutType = outType; if (auto outVecType = dyn_cast(outType)) if (outVecType.getElementType().isBF16()) intrinsicOutType = outVecType.clone(rewriter.getI16Type()); if (chipset.majorVersion != 9 || chipset < kGfx908) return op->emitOpError("MFMA only supported on gfx908+"); uint32_t getBlgpField = static_cast(op.getBlgp()); if (op.getNegateA() || op.getNegateB() || op.getNegateC()) { if (chipset < kGfx942) return op.emitOpError("negation unsupported on older than gfx942"); getBlgpField |= op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2); } std::optional maybeIntrinsic = mfmaOpToIntrinsic(op, chipset); std::optional> maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset); if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value()) return op.emitOpError("no intrinsic matching MFMA size on given chipset"); bool isScaled = !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value(); if (isScaled && (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) { return op.emitOpError( "non-default abid, blgp, and cbsz aren't supported on MFMAs that can " "be scaled as those fields are used for type information"); } StringRef intrinsicName = isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic; // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+ // allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file. 
bool allowBf16 = [&]() { if (chipset < kGfx950) return false; if (isScaled) return true; return intrinsicName.contains("16x16x32.bf16") || intrinsicName.contains("32x32x16.bf16"); }(); OperationState loweredOp(loc, intrinsicName); loweredOp.addTypes(intrinsicOutType); loweredOp.addOperands({packSmallFloatVectorOperand( rewriter, loc, adaptor.getSourceA(), allowBf16), packSmallFloatVectorOperand( rewriter, loc, adaptor.getSourceB(), allowBf16), adaptor.getDestC()}); if (isScaled) { Value zero = createI32Constant(rewriter, loc, 0); auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic; loweredOp.addOperands({/*scale A=*/zero, /*scale B=*/zero}); loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)}, {"blgp", rewriter.getI32IntegerAttr(bTypeCode)}, {"opselA", rewriter.getI32IntegerAttr(0)}, {"opselB", rewriter.getI32IntegerAttr(0)}}); } else { loweredOp.addAttributes( {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())}, {"abid", rewriter.getI32IntegerAttr(op.getAbid())}, {"blgp", rewriter.getI32IntegerAttr(getBlgpField)}}); }; Value lowered = rewriter.create(loweredOp)->getResult(0); if (outType != intrinsicOutType) lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered); rewriter.replaceOp(op, lowered); return success(); } }; struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern { ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType()); if (chipset.majorVersion != 9 || chipset < kGfx950) return op->emitOpError("scaled MFMA only supported on gfx908+"); std::optional> maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset); if (!maybeScaledIntrinsic.has_value()) return op.emitOpError( "no intrinsic 
matching scaled MFMA size on given chipset"); auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic; OperationState loweredOp(loc, intrinsicName); loweredOp.addTypes(intrinsicOutType); loweredOp.addOperands( {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()), packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()), adaptor.getDestC()}); loweredOp.addOperands( {/*scales A*/ castScaleOperand(rewriter, loc, adaptor.getScalesA()), /*scales B*/ castScaleOperand(rewriter, loc, adaptor.getScalesB())}); loweredOp.addAttributes( {{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)}, {"blgp", rewriter.getI32IntegerAttr(bTypeCode)}, {"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())}, {"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}}); Value lowered = rewriter.create(loweredOp)->getResult(0); rewriter.replaceOp(op, lowered); return success(); } }; struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern { SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto outType = typeConverter->convertType(op.getDestC().getType()); if (!outType) return rewriter.notifyMatchFailure(op, "type conversion failed"); // smfmac is supported on gfx942 and gfx950. 
if (chipset.majorVersion != 9 || chipset < kGfx942) return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+"); bool isGfx950 = chipset >= kGfx950; Value a = convertSparseMFMAVectorOperand(rewriter, loc, adaptor.getSourceA(), isGfx950); Value b = convertSparseMFMAVectorOperand(rewriter, loc, adaptor.getSourceB(), isGfx950); Value c = adaptor.getDestC(); std::optional maybeIntrinsic = smfmacOpToIntrinsic(op, chipset); if (!maybeIntrinsic.has_value()) return op.emitOpError( "no intrinsic matching sparse MFMA on the given chipset"); // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32. Value sparseIdx = LLVM::BitcastOp::create( rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx()); OperationState loweredOp(loc, maybeIntrinsic.value()); loweredOp.addTypes(outType); loweredOp.addOperands({a, b, c, sparseIdx}); loweredOp.addAttributes( {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())}, {"abid", rewriter.getI32IntegerAttr(op.getAbid())}}); Value lowered = rewriter.create(loweredOp)->getResult(0); rewriter.replaceOp(op, lowered); return success(); } }; struct WMMAOpLowering : public ConvertOpToLLVMPattern { WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto outType = typeConverter->convertType(op.getDestD().getType()); if (!outType) return rewriter.notifyMatchFailure(op, "type conversion failed"); if (chipset.majorVersion != 11 && chipset.majorVersion != 12) return op->emitOpError("WMMA only supported on gfx11 and gfx12"); bool isGFX1250 = chipset >= kGfx1250; // The WMMA operations represent vectors of bf16s as vectors of i16s // (except on gfx1250), so we need to bitcast bfloats to i16 and then // bitcast them back. 
auto aType = cast(adaptor.getSourceA().getType()); auto bType = cast(adaptor.getSourceB().getType()); auto destCType = cast(adaptor.getDestC().getType()); bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250; bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250; bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250; bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250; VectorType rawOutType = outType; if (castOutToI16) rawOutType = outType.clone(rewriter.getI16Type()); Value a = adaptor.getSourceA(); if (castAToI16) a = LLVM::BitcastOp::create(rewriter, loc, aType.clone(rewriter.getI16Type()), a); Value b = adaptor.getSourceB(); if (castBToI16) b = LLVM::BitcastOp::create(rewriter, loc, bType.clone(rewriter.getI16Type()), b); Value destC = adaptor.getDestC(); if (castDestCToI16) destC = LLVM::BitcastOp::create( rewriter, loc, destCType.clone(rewriter.getI16Type()), destC); std::optional maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); if (!maybeIntrinsic.has_value()) return op.emitOpError("no intrinsic matching WMMA on the given chipset"); if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0) return op.emitOpError("subwordOffset not supported on gfx12+"); SmallVector operands; SmallVector attrs; wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a, op.getSourceA(), operands, attrs, "signA"); wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b, op.getSourceB(), operands, attrs, "signB"); wmmaPushOutputOperand(rewriter, loc, typeConverter, destC, op.getSubwordOffset(), op.getClamp(), operands, attrs); OperationState loweredOp(loc, *maybeIntrinsic); loweredOp.addTypes(rawOutType); loweredOp.addOperands(operands); loweredOp.addAttributes(attrs); Operation *lowered = rewriter.create(loweredOp); Operation *maybeCastBack = lowered; if (rawOutType != outType) maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType, lowered->getResult(0)); rewriter.replaceOp(op, 
maybeCastBack->getResults()); return success(); } }; struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern { ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto outType = typeConverter->convertType(op.getDestD().getType()); if (!outType) return rewriter.notifyMatchFailure(op, "type conversion failed"); if (chipset < kGfx1250) return op->emitOpError("WMMA scale only supported on gfx1250+"); int64_t m = op.getM(); int64_t n = op.getN(); int64_t k = op.getK(); Type aElemType = getElementTypeOrSelf(op.getSourceA().getType()); Type bElemType = getElementTypeOrSelf(op.getSourceB().getType()); std::optional aFmtCode = smallFloatTypeToFormatCode(aElemType); std::optional bFmtCode = smallFloatTypeToFormatCode(bElemType); if (!aFmtCode || !bFmtCode) return op.emitOpError("unsupported element types for scaled_wmma"); // Get scale vector types and determine variant (scale vs scale16). auto scaleAVecType = cast(op.getScaleA().getType()); auto scaleBVecType = cast(op.getScaleB().getType()); if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements()) return op.emitOpError("scaleA and scaleB must have equal vector length"); // Extract scale format from element types. Type scaleAElemType = scaleAVecType.getElementType(); Type scaleBElemType = scaleBVecType.getElementType(); std::optional scaleAFmt = getWmmaScaleFormat(scaleAElemType); std::optional scaleBFmt = getWmmaScaleFormat(scaleBElemType); if (!scaleAFmt || !scaleBFmt) return op.emitOpError("unsupported scale element types"); // Determine which intrinsic to use based on dimensions. 
bool isScale16 = (scaleAVecType.getNumElements() == 8); std::optional intrinsicName = getScaledWmmaIntrinsicName(m, n, k, isScale16); if (!intrinsicName) return op.emitOpError("unsupported scaled_wmma dimensions: ") << m << "x" << n << "x" << k; SmallVector attrs; // The f4 variant does not have fmtA and fmtB attributes. bool is32x16 = (m == 32 && n == 16 && k == 128); if (!is32x16) { attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode)); attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode)); } // modC uses default value of 0. attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0)); // Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the // half of the wave that is being selected (0 or 1). attrs.emplace_back( "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16)); attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt)); attrs.emplace_back( "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16)); attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt)); // Reuse flags use default value of false. attrs.emplace_back("reuseA", rewriter.getBoolAttr(false)); attrs.emplace_back("reuseB", rewriter.getBoolAttr(false)); // Convert typed float vectors to packed format. Value sourceA = packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()); Value sourceB = packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()); // Pack scale vectors into i32/i64. Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA()); Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB()); // Create the intrinsic call. 
OperationState loweredOp(loc, *intrinsicName); loweredOp.addTypes(outType); loweredOp.addOperands( {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB}); loweredOp.addAttributes(attrs); Operation *lowered = rewriter.create(loweredOp); rewriter.replaceOp(op, lowered->getResults()); return success(); } }; struct TransposeLoadOpLowering : public ConvertOpToLLVMPattern { TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (chipset != kGfx950) return op.emitOpError("Non-gfx950 chipset not supported"); Location loc = op.getLoc(); auto srcMemRefType = cast(op.getSrc().getType()); // Elements in subbyte memrefs are stored non-contiguously, // reject if source is sub-byte memref. Use emulated memrefs instead. size_t srcElementSize = srcMemRefType.getElementType().getIntOrFloatBitWidth(); if (srcElementSize < 8) return op.emitOpError("Expect source memref to have at least 8 bits " "element size, got ") << srcElementSize; auto resultType = cast(op.getResult().getType()); Value srcPtr = getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(), (adaptor.getSrcIndices())); size_t numElements = resultType.getNumElements(); size_t elementTypeSize = resultType.getElementType().getIntOrFloatBitWidth(); // ROCDL transpose load intrinsics return vectors of 32-bit integers, if // the element size is smaller than 16 bits. 
Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32, rewriter.getIntegerType(32)); Type llvmResultType = typeConverter->convertType(resultType); switch (elementTypeSize) { case 4: { assert(numElements == 16); auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc, rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } case 6: { assert(numElements == 16); auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc, rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } case 8: { assert(numElements == 8); auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc, rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } case 16: { assert(numElements == 4); rewriter.replaceOpWithNewOp(op, llvmResultType, srcPtr); break; } default: return op.emitOpError("Unsupported element size for transpose load"); } return success(); } }; struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern { GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (chipset.majorVersion < 9 || chipset.majorVersion > 10) return op.emitOpError("pre-gfx9 and post-gfx10 not supported"); Location loc = op.getLoc(); auto srcMemRefType = cast(op.getSrc().getType()); auto dstMemRefType = cast(op.getDst().getType()); // TODO: instead of only transfering one element per thread, we could // augment it to transfer multiple elements per thread by issuing multiple // `global_load_lds` instructions. 
Type transferType = op.getTransferType(); int loadWidth = [&]() -> int { if (auto transferVectorType = dyn_cast(transferType)) { return (transferVectorType.getNumElements() * transferVectorType.getElementTypeBitWidth()) / 8; } return transferType.getIntOrFloatBitWidth() / 8; }(); // Currently only 1, 2, 4, 12 and 16 byte loads are supported. if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth)) return op.emitOpError("chipset unsupported element size"); if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth)) return op.emitOpError("Gather to LDS instructions with 12-byte and " "16-byte load widths are only supported on gfx950"); Value srcPtr = getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(), (adaptor.getSrcIndices())); Value dstPtr = getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(), (adaptor.getDstIndices())); rewriter.replaceOpWithNewOp( op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth), /*offset=*/rewriter.getI32IntegerAttr(0), /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{}, ArrayAttr{}); return success(); } }; namespace { struct ExtPackedFp8OpLowering final : public ConvertOpToLLVMPattern { ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct ScaledExtPackedMatrixOpLowering final : public ConvertOpToLLVMPattern { ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct PackedTrunc2xFp8OpLowering final : public ConvertOpToLLVMPattern { PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter, 
Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct PackedStochRoundFp8OpLowering final : public ConvertOpToLLVMPattern { PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct ScaledExtPackedOpLowering final : public ConvertOpToLLVMPattern { ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct PackedScaledTruncOpLowering final : public ConvertOpToLLVMPattern { PackedScaledTruncOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // end namespace LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); if (!(chipset == kGfx942 || hasOcpFp8(chipset))) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); Type v4i8 = getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type())); Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); Type f32 = getTypeConverter()->convertType(op.getResult().getType()); Value 
source = adaptor.getSource(); auto sourceVecType = dyn_cast(op.getSource().getType()); auto resultVecType = dyn_cast(op.getResult().getType()); Type sourceElemType = getElementTypeOrSelf(op.getSource()); // Extend to a v4i8 if (!sourceVecType || sourceVecType.getNumElements() < 4) { Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8); if (!sourceVecType) { longVec = LLVM::InsertElementOp::create( rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); longVec = LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); } } source = longVec; } Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); if (resultVecType) { if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, op.getIndex()); } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, op.getIndex()); } } else { if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, op.getIndex()); } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, op.getIndex()); } } return success(); } int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf, int32_t firstScaleByte) { // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f* // operations, the attributes blockSize, sourceType, scaleWaveHalf, and // firstScaleByte are merged into a single attribute scaleSel. This is how // those values are merged together. (Note: scaleWaveHalf isn't a high-level // attribute but is derifed from firstScaleLane). 
assert(llvm::is_contained({16, 32}, blockSize)); assert(llvm::is_contained({4u, 6u, 8u}, bitWidth)); const bool isFp8 = bitWidth == 8; const bool isBlock16 = blockSize == 16; if (!isFp8) { int32_t bit0 = isBlock16; assert(llvm::is_contained({0, 1, 2}, firstScaleByte)); int32_t bit1 = (firstScaleByte == 2) << 1; assert(llvm::is_contained({0, 1}, scaleWaveHalf)); int32_t bit2 = scaleWaveHalf << 2; return bit2 | bit1 | bit0; } int32_t bit0 = isBlock16; // firstScaleByte is guaranteed to be defined by two bits. assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte)); int32_t bits2and1 = firstScaleByte << 1; assert(llvm::is_contained({0, 1}, scaleWaveHalf)); int32_t bit3 = scaleWaveHalf << 3; int32_t bits = bit3 | bits2and1 | bit0; // These are invalid cases. assert(!llvm::is_contained( {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits)); return bits; } static std::optional scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) { using fp4 = Float4E2M1FNType; using fp8 = Float8E4M3FNType; using bf8 = Float8E5M2Type; using fp6 = Float6E2M3FNType; using bf6 = Float6E3M2FNType; if (isa(srcElemType)) { if (destElemType.isF16()) return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName(); if (destElemType.isBF16()) return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName(); if (destElemType.isF32()) return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName(); return std::nullopt; } if (isa(srcElemType)) { if (destElemType.isF16()) return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName(); if (destElemType.isBF16()) return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName(); if (destElemType.isF32()) return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName(); return std::nullopt; } if (isa(srcElemType)) { if (destElemType.isF16()) return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName(); if (destElemType.isBF16()) return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName(); if (destElemType.isF32()) return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName(); return 
std::nullopt; } if (isa(srcElemType)) { if (destElemType.isF16()) return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName(); if (destElemType.isBF16()) return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName(); if (destElemType.isF32()) return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName(); return std::nullopt; } if (isa(srcElemType)) { if (destElemType.isF16()) return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName(); if (destElemType.isBF16()) return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName(); if (destElemType.isF32()) return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName(); return std::nullopt; } llvm_unreachable("invalid combination of element types for packed conversion " "instructions"); } LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite( ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { using fp4 = Float4E2M1FNType; using fp8 = Float8E4M3FNType; using bf8 = Float8E5M2Type; using fp6 = Float6E2M3FNType; using bf6 = Float6E3M2FNType; Location loc = op.getLoc(); if (chipset != kGfx1250) { return rewriter.notifyMatchFailure( loc, "Scaled fp packed conversion instructions are not available on target " "architecture and their emulation is not implemented"); } // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that // is being selected. 
int32_t scaleWaveHalf = op.getFirstScaleLane() / 16; int32_t firstScaleByte = op.getFirstScaleByte(); int32_t blockSize = op.getBlockSize(); auto sourceType = cast(op.getSource().getType()); auto srcElemType = cast(sourceType.getElementType()); unsigned bitWidth = srcElemType.getWidth(); auto targetType = cast(op.getResult().getType()); auto destElemType = cast(targetType.getElementType()); IntegerType i32 = rewriter.getI32Type(); Value source = adaptor.getSource(); Type llvmResultType = typeConverter->convertType(op.getResult().getType()); Type packedType = nullptr; if (isa(srcElemType)) { packedType = i32; packedType = getTypeConverter()->convertType(packedType); } else if (isa(srcElemType)) { packedType = VectorType::get(2, i32); packedType = getTypeConverter()->convertType(packedType); } else if (isa(srcElemType)) { packedType = VectorType::get(3, i32); packedType = getTypeConverter()->convertType(packedType); } else { llvm_unreachable("invalid element type for packed scaled ext"); } if (!packedType || !llvmResultType) { return rewriter.notifyMatchFailure(op, "type conversion failed"); } std::optional maybeIntrinsic = scaledExtPacked816ToIntrinsic(srcElemType, destElemType); if (!maybeIntrinsic.has_value()) return op.emitOpError( "no intrinsic matching packed scaled conversion on the given chipset"); int32_t scaleSel = getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte); Value castedScale = LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale()); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, packedType, source); OperationState loweredOp(loc, *maybeIntrinsic); loweredOp.addTypes({llvmResultType}); loweredOp.addOperands({castedSource, castedScale}); SmallVector attrs; attrs.push_back( NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel))); loweredOp.addAttributes(attrs); Operation *lowered = rewriter.create(loweredOp); rewriter.replaceOp(op, lowered); return success(); } LogicalResult 
ScaledExtPackedOpLowering::matchAndRewrite(
    ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Lowers ScaledExtPackedOp (gfx950 only) to the ROCDL scaled conversion op
  // matching the (source element type, result element type) pair. Sources
  // shorter than one packed 32-bit word are zero-extended to a full word
  // first. Fails to match when no op exists for the type pair.
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();

  // `cast` asserts on a non-vector type, so `sourceVecType` is never null
  // below and no scalar-source path is needed.
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();
  VectorType destVecType = cast<VectorType>(op.getResult().getType());
  Type destElemType = destVecType.getElementType();

  // Shape of one packed 32-bit word for this source format.
  VectorType packedVecType;
  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
    VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
  } else if (isa<Float4E2M1FNType>(sourceElemType)) {
    VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
  } else {
    llvm_unreachable("invalid element type for scaled ext");
  }

  // Extend to a packedVectorType
  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
    Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
    for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
      Value idx = createI32Constant(rewriter, loc, i);
      Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
      longVec =
          LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);

  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else
    return rewriter.notifyMatchFailure(
        loc, "no scaled conversion instruction for this combination of "
             "source and result element types");

  return success();
}

LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v2i16 = getTypeConverter()->convertType(
      VectorType::get(2, rewriter.getI16Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ?
i32 : v2i16; Value source = adaptor.getSource(); Value scale = adaptor.getScale(); Value existing = adaptor.getExisting(); if (existing) existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing); else existing = LLVM::ZeroOp::create(rewriter, loc, intResultType); if (sourceVecType.getNumElements() < 2) { Value c0 = createI32Constant(rewriter, loc, 0); Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); VectorType v2 = VectorType::get(2, sourceElemType); source = LLVM::ZeroOp::create(rewriter, loc, v2); source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0); } Value sourceA, sourceB; if (sourceElemType.isF32()) { Value c0 = createI32Constant(rewriter, loc, 0); Value c1 = createI32Constant(rewriter, loc, 1); sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1); } Value result; if (sourceElemType.isF32() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); else if (sourceElemType.isF16() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkBf8F16Op::create( rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkBf8Bf16Op::create( rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isF32() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); else if (sourceElemType.isF16() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkFp8F16Op::create( rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkFp8Bf16Op::create( rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if 
(sourceElemType.isF32() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); else if (sourceElemType.isF16() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkFp4F16Op::create( rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkFp4Bf16Op::create( rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else return failure(); result = rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(resultType), result); return success(); } LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); if (!(chipset == kGfx942 || hasOcpFp8(chipset))) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); Type resultType = op.getResult().getType(); Type resultElemType = getElementTypeOrSelf(resultType); Value sourceA = adaptor.getSourceA(); Value sourceB = adaptor.getSourceB(); if (!sourceB) sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType()); Value existing = adaptor.getExisting(); if (existing) existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); else existing = LLVM::UndefOp::create(rewriter, loc, i32); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB, existing, op.getWordIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB, existing, op.getWordIndex()); result = rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(resultType), 
result); return success(); } LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); if (!(chipset == kGfx942 || hasOcpFp8(chipset))) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); Type resultType = op.getResult().getType(); Type resultElemType = getElementTypeOrSelf(resultType); Value source = adaptor.getSource(); Value stoch = adaptor.getStochiasticParam(); Value existing = adaptor.getExisting(); if (existing) existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); else existing = LLVM::UndefOp::create(rewriter, loc, i32); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch, existing, op.getStoreIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch, existing, op.getStoreIndex()); result = rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(resultType), result); return success(); } // Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp // operation into the corresponding ROCDL instructions. 
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern { AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Convert the source operand to the corresponding LLVM type Location loc = DppOp.getLoc(); Value src = adaptor.getSrc(); Value old = adaptor.getOld(); Type srcType = src.getType(); Type oldType = old.getType(); Type llvmType = nullptr; if (srcType.getIntOrFloatBitWidth() < 32) { llvmType = rewriter.getI32Type(); } else if (isa(srcType)) { llvmType = (srcType.getIntOrFloatBitWidth() == 32) ? rewriter.getF32Type() : rewriter.getF64Type(); } else if (isa(srcType)) { llvmType = (srcType.getIntOrFloatBitWidth() == 32) ? rewriter.getI32Type() : rewriter.getI64Type(); } auto llvmSrcIntType = typeConverter->convertType( rewriter.getIntegerType(srcType.getIntOrFloatBitWidth())); // If the source type is less of 32, use bitcast to convert it to i32. 
auto convertOperand = [&](Value operand, Type operandType) { if (operandType.getIntOrFloatBitWidth() <= 16) { if (llvm::isa(operandType)) { operand = LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand); } auto llvmVecType = typeConverter->convertType(mlir::VectorType::get( 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType)); Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType); operand = LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand, createI32Constant(rewriter, loc, 0)); operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand); } return operand; }; src = convertOperand(src, srcType); old = convertOperand(old, oldType); // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h enum DppCtrl : unsigned { ROW_SHL0 = 0x100, ROW_SHR0 = 0x110, ROW_ROR0 = 0x120, WAVE_SHL1 = 0x130, WAVE_ROL1 = 0x134, WAVE_SHR1 = 0x138, WAVE_ROR1 = 0x13C, ROW_MIRROR = 0x140, ROW_HALF_MIRROR = 0x141, BCAST15 = 0x142, BCAST31 = 0x143, }; auto kind = DppOp.getKind(); auto permArgument = DppOp.getPermArgument(); uint32_t DppCtrl = 0; switch (kind) { case DPPPerm::quad_perm: { auto quadPermAttr = cast(*permArgument); int32_t i = 0; for (auto elem : quadPermAttr.getAsRange()) { uint32_t num = elem.getInt(); DppCtrl |= num << (i * 2); i++; } break; } case DPPPerm::row_shl: { auto intAttr = cast(*permArgument); DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0; break; } case DPPPerm::row_shr: { auto intAttr = cast(*permArgument); DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0; break; } case DPPPerm::row_ror: { auto intAttr = cast(*permArgument); DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0; break; } case DPPPerm::wave_shl: DppCtrl = DppCtrl::WAVE_SHL1; break; case DPPPerm::wave_shr: DppCtrl = DppCtrl::WAVE_SHR1; break; case DPPPerm::wave_rol: DppCtrl = DppCtrl::WAVE_ROL1; break; case DPPPerm::wave_ror: DppCtrl = DppCtrl::WAVE_ROR1; break; case DPPPerm::row_mirror: DppCtrl = DppCtrl::ROW_MIRROR; break; case 
DPPPerm::row_half_mirror: DppCtrl = DppCtrl::ROW_HALF_MIRROR; break; case DPPPerm::row_bcast_15: DppCtrl = DppCtrl::BCAST15; break; case DPPPerm::row_bcast_31: DppCtrl = DppCtrl::BCAST31; break; } // Check for row_mask, bank_mask, bound_ctrl if they exist and create // constants auto rowMask = DppOp->getAttrOfType("row_mask").getInt(); auto bankMask = DppOp->getAttrOfType("bank_mask").getInt(); bool boundCtrl = DppOp->getAttrOfType("bound_ctrl").getValue(); // create a ROCDL_DPPMovOp instruction with the appropriate attributes auto dppMovOp = ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl); Value result = dppMovOp.getRes(); if (srcType.getIntOrFloatBitWidth() < 32) { result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result); if (!llvm::isa(srcType)) { result = LLVM::BitcastOp::create(rewriter, loc, srcType, result); } } // We are replacing the AMDGPU_DPPOp instruction with the new // ROCDL_DPPMovOp instruction rewriter.replaceOp(DppOp, ValueRange(result)); return success(); } }; struct AMDGPUSwizzleBitModeLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type i32 = rewriter.getI32Type(); Value src = adaptor.getSrc(); SmallVector decomposed = LLVM::decomposeValue(rewriter, loc, src, i32); unsigned andMask = op.getAndMask(); unsigned orMask = op.getOrMask(); unsigned xorMask = op.getXorMask(); // bit 15 is 0 for the BitMode swizzle. 
// https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/ unsigned mask = andMask | (orMask << 5) | (xorMask << 10); Value maskValue = createI32Constant(rewriter, loc, mask); SmallVector swizzled; for (Value v : decomposed) { Value res = ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue); swizzled.emplace_back(res); } Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType()); rewriter.replaceOp(op, result); return success(); } }; struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (chipset < kGfx950) return op->emitOpError("permlane_swap is only supported on gfx950+"); Location loc = op.getLoc(); Type i32 = rewriter.getI32Type(); Value src = adaptor.getSrc(); unsigned rowLength = op.getRowLength(); bool fi = op.getFetchInactive(); bool boundctrl = op.getBoundCtrl(); SmallVector decomposed = LLVM::decomposeValue(rewriter, loc, src, i32); SmallVector permuted; for (Value v : decomposed) { Value res; Type i32pair = LLVM::LLVMStructType::getLiteral( rewriter.getContext(), {v.getType(), v.getType()}); if (rowLength == 16) res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi, boundctrl); else if (rowLength == 32) res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi, boundctrl); else llvm_unreachable("unsupported row length"); Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0}); Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1}); Value isEqual = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, vdst0, v); // Per `permlane(16|32)` semantics: if the first extracted element equals // 'v', the result is the second 
// element; otherwise it is the first.
      Value vdstNew =
          LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
      permuted.emplace_back(vdstNew);
    }

    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// In-LDS Barrier Operations
//===----------------------------------------------------------------------===//

// Bit layout of ds_barrier_state (as i64):
//   [63:32] init count (32 bits)
//   [31:29] phase (3 bits)
//   [28:0]  pending count (29 bits)
constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
// The phase field starts immediately above the pending count.
constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
constexpr int32_t kDsBarrierInitCountPos = 32;
// Mask selecting the low `kDsBarrierPendingCountBitWidth` bits.
constexpr int32_t kDsBarrierPendingCountMask =
    (1 << kDsBarrierPendingCountBitWidth) - 1;

/// Lowers `DsBarrierInitOp` to an atomic store of the initial barrier state
/// (release ordering, "workgroup" syncscope). Emits an op error on chipsets
/// below gfx1250.
struct DsBarrierInitOpLowering : public ConvertOpToLLVMPattern {
  Chipset chipset;

  DsBarrierInitOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();
    Type i64 = rewriter.getI64Type();

    // Address of the barrier element selected by the op's indices.
    MemRefType memrefType = cast(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // Note: We give participants as the number of arrivals that have to occur
    // before the phase changes. Hardware changes the phase when updating the
    // pending count would underflow, so we subtract 1 to get the behavior we're
    // looking for.
    Value initCount =
        LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
                            createI32Constant(rewriter, loc, 1));

    // Just a bit of paranoia, but this also allows for configurable width if
    // that becomes a thing.
Value countMask =
        createI32Constant(rewriter, loc, kDsBarrierPendingCountMask);
    Value maskedCount32 =
        LLVM::AndOp::create(rewriter, loc, initCount, countMask);
    Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);

    // The same masked count is written to both the init-count field (upper
    // half) and the pending-count field (low bits); the phase bits stay zero.
    Value initCountShifted = LLVM::ShlOp::create(
        rewriter, loc, maskedCount,
        createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
    Value barrierState =
        LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);

    LLVM::StoreOp::create(
        rewriter, loc, barrierState, ptr, /*alignment=*/8,
        /*isVolatile=*/false, /*isNonTemporal=*/false,
        /*isInvariantGroup=*/false, LLVM::AtomicOrdering::release,
        /*syncscope=*/"workgroup");

    // The op is erased rather than replaced: no replacement value is supplied.
    rewriter.eraseOp(op);
    return success();
  }
};

/// Lowers `DsBarrierPollStateOp` to an i64 atomic load of the barrier state
/// (acquire ordering, "workgroup" syncscope). Emits an op error on chipsets
/// below gfx1250.
struct DsBarrierPollStateOpLowering : public ConvertOpToLLVMPattern {
  Chipset chipset;

  DsBarrierPollStateOpLowering(const LLVMTypeConverter &converter,
                               Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();
    Type i64 = rewriter.getI64Type();

    MemRefType memrefType = cast(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // Atomic load with workgroup scope and acquire ordering should be what
    // we're looking for.
rewriter.replaceOpWithNewOp(
        op, i64, ptr, /*alignment=*/8, /*volatile_=*/false,
        /*nontemporal=*/false, /*invariant=*/false,
        /*invariantGroup=*/false, LLVM::AtomicOrdering::acquire,
        /*syncscope=*/"workgroup");

    return success();
  }
};

/// Lowers `DsAsyncBarrierArriveOp` to a replacement op that takes only the
/// barrier element pointer (no alias/TBAA metadata attached). Emits an op
/// error on chipsets below gfx1250.
struct DsAsyncBarrierArriveOpLowering : public ConvertOpToLLVMPattern {
  Chipset chipset;

  DsAsyncBarrierArriveOpLowering(const LLVMTypeConverter &converter,
                                 Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();

    MemRefType memrefType = cast(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    rewriter.replaceOpWithNewOp(
        op, ptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
        /*tbaa=*/nullptr);
    return success();
  }
};

/// Lowers `DsBarrierArriveOp` to a replacement op with an i64 result that
/// takes the barrier element pointer and the arrival count. Emits an op error
/// on chipsets below gfx1250.
struct DsBarrierArriveOpLowering : public ConvertOpToLLVMPattern {
  Chipset chipset;

  DsBarrierArriveOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {
  }

  LogicalResult
  matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();
    Type i64 = rewriter.getI64Type();

    MemRefType memrefType = cast(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    rewriter.replaceOpWithNewOp(
        op, i64, ptr, adaptor.getCount(), /*alias_scopes=*/nullptr,
        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
    return success();
  }
};

/// Extracts the phase field from a ds_barrier_state value: the state is
/// truncated to i32 and shifted right by `kDsBarrierPhasePos`.
struct DsBarrierStatePhaseOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();

    Value state = adaptor.getState();
    // Truncating to i32 drops the init count held in the upper half.
    Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
    Value phase = LLVM::LShrOp::create(
        rewriter, loc, noInitCount,
        createI32Constant(rewriter, loc, kDsBarrierPhasePos));

    rewriter.replaceOp(op, phase);
    return success();
  }
};

/// Extracts the pending-count field from a ds_barrier_state value: the state
/// is truncated to i32 and masked with `kDsBarrierPendingCountMask`.
struct DsBarrierStatePendingCountOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();

    Value state = adaptor.getState();
    Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
    Value pendingCount = LLVM::AndOp::create(
        rewriter, loc, noInitCount,
        createI32Constant(rewriter, loc,
                          static_cast(kDsBarrierPendingCountMask)));

    rewriter.replaceOp(op, pendingCount);
    return success();
  }
};

/// Extracts the init-count field from a ds_barrier_state value: the state is
/// shifted right by `kDsBarrierInitCountPos` and truncated to i32.
struct DsBarrierStateInitCountOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();

    Value state = adaptor.getState();
    Value initCountI64 = LLVM::LShrOp::create(
        rewriter, loc, state,
        createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
    Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);

    rewriter.replaceOp(op, initCount);
    return success();
  }
};

/// Produces the parity of the phase field as an i1: the lowest bit of the
/// phase extracted from a ds_barrier_state value.
struct DsBarrierStatePhaseParityLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i1 = rewriter.getI1Type();

    Value state =
adaptor.getState(); Value noInitCount = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state); Value phase = LLVM::LShrOp::create( rewriter, loc, noInitCount, createI32Constant(rewriter, loc, kDsBarrierPhasePos)); Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase); rewriter.replaceOp(op, parity); return success(); } }; //===----------------------------------------------------------------------===// // Tensor Data Mover (TDM) //===----------------------------------------------------------------------===// static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc, Value accumulator, Value value, int64_t shift) { shift = shift % 32; Value shiftAmount; if (shift != 0) { shiftAmount = createI32Constant(rewriter, loc, shift % 32); value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount); } if (matchPattern(accumulator, mlir::m_Zero())) return value; constexpr bool isDisjoint = true; return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint); } template struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Adaptor = typename ConvertOpToLLVMPattern::OpAdaptor; AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(BaseOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (chipset < kGfx1250) return op->emitOpError("make_dma_base is only supported on gfx1250"); Location loc = op.getLoc(); constexpr int32_t constlen = 4; Value consts[constlen]; for (int64_t i = 0; i < constlen; ++i) consts[i] = createI32Constant(rewriter, loc, i); constexpr int32_t sgprslen = constlen; Value sgprs[sgprslen]; for (int64_t i = 0; i < sgprslen; ++i) { sgprs[i] = consts[0]; } sgprs[0] = consts[1]; if constexpr (BaseOp::isGather()) { sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30); auto 
type = cast(op.getResult().getType());
      Type indexType = type.getIndexType();
      unsigned indexSize = indexType.getIntOrFloatBitWidth();
      assert(llvm::is_contained({16u, 32u}, indexSize) &&
             "expected index_size to be 16 or 32");
      // idx is 0 for 16-bit indices and 1 for 32-bit indices; bit 31 of word 0
      // is set only in the 32-bit case.
      unsigned idx = (indexSize / 16) - 1;

      if (idx)
        sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
    }

    // Pointer to the addressed LDS element.
    ValueRange ldsIndices = adaptor.getLdsIndices();
    Value lds = adaptor.getLds();
    auto ldsMemRefType = cast(op.getLds().getType());

    Value ldsPtr = ConvertToLLVMPattern::getStridedElementPtr(
        rewriter, loc, ldsMemRefType, lds, ldsIndices);

    // Pointer to the addressed global element.
    ValueRange globalIndices = adaptor.getGlobalIndices();
    Value global = adaptor.getGlobal();
    auto globalMemRefType = cast(op.getGlobal().getType());

    Value globalPtr = ConvertToLLVMPattern::getStridedElementPtr(
        rewriter, loc, globalMemRefType, global, globalIndices);

    Type i32 = rewriter.getI32Type();
    Type i64 = rewriter.getI64Type();

    // Word 1: the LDS address as an i32.
    sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
    // Word 2: low 32 bits of the global address.
    Value castForGlobalAddr =
        LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);

    sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);

    Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
                                       createI64Constant(rewriter, loc, 32));

    Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);

    // Word 3: the low 25 bits of the upper address half, with the constant 2
    // placed at bit 30.
    Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
    highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);

    sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);

    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed");
    Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);

    // Insert word i at vector position i (consts[i] holds the value i).
    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
      result =
          LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Conversion pattern for descriptor ops (`DescriptorOp`). Its helper methods
/// each write one descriptor field into an i32 word and return the updated
/// word.
template struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern {
  using
ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using OpAdaptor = typename ConvertOpToLLVMPattern::OpAdaptor; AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); } Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Location loc, Value sgpr0) const { Value mask = op.getWorkgroupMask(); if (!mask) return sgpr0; Type i16 = rewriter.getI16Type(); mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask); Type i32 = rewriter.getI32Type(); Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask); return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0); } Value setDataSize(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Location loc, Value sgpr0, ArrayRef consts) const { unsigned elementTypeWidthInBits = op.getElementTypeWidth(); assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) && "expected type width to be 8, 16, 32, or 64."); int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8); Value size = consts[idx]; return setValueAtOffset(rewriter, loc, sgpr0, size, 16); } Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Location loc, Value sgpr0, ArrayRef consts) const { if (!adaptor.getAtomicBarrierAddress()) return sgpr0; return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18); } Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Location loc, Value sgpr0, ArrayRef consts) const { if (!adaptor.getGlobalIncrement()) return sgpr0; // Value is ignored when in gather mode. // TODO: emit error earlier? 
return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
  }

  /// Sets bit 20 of `sgpr0` when the op has a pad amount.
  Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
  }

  /// Sets bit 21 of `sgpr0` when the op has a workgroup mask.
  Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr0, ArrayRef consts) const {
    if (!op.getWorkgroupMask())
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
  }

  /// Encodes the pad interval as `cttz(padInterval) - 1` and writes it at bit
  /// offset 22 of `sgpr0`. No-op when the op has no pad amount.
  Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgpr0, ArrayRef consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    // pre-condition: padInterval can be a power of two between 2 and 256.
    // TODO: Validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime.
    IntegerType i32 = rewriter.getI32Type();
    Value padInterval = adaptor.getPadInterval();
    // For a power of two, the trailing-zero count equals log2 of the value.
    padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
                                                     padInterval, false);
    padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
    // post-condition: padInterval can be a value between 0 and 7.
    return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
  }

  /// Writes `padAmount - 1` at bit offset 25 of `sgpr0`. No-op when the op has
  /// no pad amount.
  Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    // pre-condition: padAmount is a value between 1-128.
    // TODO: Validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime.
Value padAmount = adaptor.getPadAmount();
    padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
    // post-condition: padAmount is a value between 0-127.
    return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
  }

  /// Writes the atomic barrier address, shifted right by 3 and masked to 16
  /// bits, into the low bits of `sgpr1` (descriptor bit offset 32). No-op when
  /// no atomic barrier address is present.
  Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Location loc, Value sgpr1,
                                ArrayRef consts) const {
    if (!adaptor.getAtomicBarrierAddress())
      return sgpr1;

    // Resolve the address of the barrier element selected by the indices.
    Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
    auto barrierAddressTy = cast(op.getAtomicBarrierAddress().getType());
    ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
    atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
        rewriter, loc, barrierAddressTy, atomicBarrierAddress,
        atomicBarrierIndices);
    IntegerType i32 = rewriter.getI32Type();
    // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
    // that the 3 LSBs are zero.
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface
    // that instruments conditions that need to be checked at runtime.
atomicBarrierAddress =
        LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
    // consts[3] is the i32 constant 3: drop the three alignment bits.
    atomicBarrierAddress =
        LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
    Value mask = createI32Constant(rewriter, loc, 0xFFFF);
    atomicBarrierAddress =
        LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
    return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
  }

  /// Writes global size number `dimX` (counted from the last, i.e. innermost,
  /// size) as a 32-bit field starting at descriptor bit `offset`, split across
  /// two adjacent words. Returns both words unchanged when the op has no such
  /// dimension.
  std::pair setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Value sgpr1, Value sgpr2, ArrayRef consts,
                          uint64_t dimX, uint32_t offset) const {
    ArrayRef globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= dimX)
      return {sgpr1, sgpr2};

    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    // pre-condition: tensorDimX is less than 2^32-1
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface that instruments
    // conditions that need to be checked at runtime. This could also be fixed
    // by saying that mixedGlobalSizes is a DynamicI32List.
Value tensorDimX;
    // Static sizes become i32 constants; dynamic sizes are truncated to i32.
    if (auto attr = dyn_cast(tensorDimXOpFoldResult)) {
      tensorDimX =
          createI32Constant(rewriter, loc, cast(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }

    // The first word receives the value shifted to `offset % 32`; the second
    // word receives the value shifted right by 16 at `offset + 16`.
    sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);

    Value c16 = createI32Constant(rewriter, loc, 16);
    Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
    sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
    return {sgpr1, sgpr2};
  }

  /// Writes global size 0 at descriptor bit offset 48.
  std::pair setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Value sgpr1, Value sgpr2, ArrayRef consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
                         48);
  }

  /// Writes global size 1 at descriptor bit offset 80.
  std::pair setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Value sgpr2, Value sgpr3, ArrayRef consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
                         80);
  }

  /// Writes shared size number `dimX` (counted from the last size) into `sgpr`
  /// at descriptor bit `offset`. Returns `sgpr` unchanged when the op has no
  /// such dimension.
  Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr, ArrayRef consts, size_t dimX,
                    int64_t offset) const {
    ArrayRef sharedStaticSizes = adaptor.getSharedStaticSizes();
    ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
    SmallVector mixedSharedSizes =
        getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
    if (mixedSharedSizes.size() <= dimX)
      return sgpr;

    OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
    // pre-condition: tileDimX is less than 2^16-1
    // TODO: Validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime.
// This could also be fixed by saying that
    // mixedSharedSizes is a DynamicI16List.
    Value tileDimX;
    // Static sizes become i32 constants; dynamic sizes are truncated to i32.
    if (auto attr = dyn_cast(tileDimXOpFoldResult)) {
      tileDimX =
          createI32Constant(rewriter, loc, cast(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tileDimX = cast(tileDimXOpFoldResult);
      tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
    }

    return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
  }

  /// Writes shared size 0 at descriptor bit offset 112.
  Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr3, ArrayRef consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
  }

  /// Writes shared size 1 at descriptor bit offset 128.
  Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
  }

  /// Writes the length of the rank-1 indices operand at descriptor bit offset
  /// 128, the same position `setTileDim1` uses.
  Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr4, ArrayRef consts) const {
    auto type = cast(op.getIndices().getType());
    ArrayRef shape = type.getShape();
    assert(shape.size() == 1 && "expected shape to be of rank 1.");
    unsigned length = shape.back();
    assert(0 < length && length <= 16 && "expected length to be at most 16.");
    Value value = createI32Constant(rewriter, loc, length);
    return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
  }

  /// Gather descriptors store the number of valid indices where non-gather
  /// descriptors store tile dimension 1.
  Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc, Value sgpr4,
                                  ArrayRef consts) const {
    if constexpr (DescriptorOp::isGather())
      return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
    return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
  }

  /// Writes shared size 2 at descriptor bit offset 144 (non-gather only).
  Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef consts) const {
    // Value is ignored when in gather mode.
if constexpr (DescriptorOp::isGather())
      return sgpr4;
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
  }

  /// Writes a global stride as a 48-bit field starting at descriptor bit
  /// `offset`, split across the words `sgprY` (low part) and `sgprZ` (high
  /// part). The stride used is element `dimX + 1` counted from the last
  /// stride. Returns both words unchanged when the op has no such stride.
  std::pair setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Location loc, Value sgprY, Value sgprZ,
                                ArrayRef consts, size_t dimX,
                                int64_t offset) const {
    ArrayRef globalStaticStrides = adaptor.getGlobalStaticStrides();
    ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
    SmallVector mixedGlobalStrides =
        getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);

    if (mixedGlobalStrides.size() <= (dimX + 1))
      return {sgprY, sgprZ};

    OpFoldResult tensorDimXStrideOpFoldResult =
        *(mixedGlobalStrides.rbegin() + dimX + 1);
    // pre-condition: tensorDimXStride is less than 2^48-1
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface that instruments
    // conditions that need to be checked at runtime.
    Value tensorDimXStride;
    if (auto attr = dyn_cast(tensorDimXStrideOpFoldResult))
      tensorDimXStride =
          createI64Constant(rewriter, loc, cast(attr).getInt());
    else
      tensorDimXStride = cast(tensorDimXStrideOpFoldResult);

    // Keep only the low 48 bits of the stride.
    constexpr int64_t first48bits = (1ll << 48) - 1;
    Value mask = createI64Constant(rewriter, loc, first48bits);
    tensorDimXStride =
        LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
    IntegerType i32 = rewriter.getI32Type();
    Value tensorDimXStrideLow =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
    sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);

    // Number of stride bits that landed in `sgprY`: a full 32 when the field
    // is word-aligned, otherwise `offset % 32`.
    int64_t shift = (offset % 32) == 0 ?
32 : offset % 32;
    // The remaining high bits go to the start of the next word.
    Value shiftVal = createI64Constant(rewriter, loc, shift);
    Value tensorDimXStrideHigh =
        LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
    tensorDimXStrideHigh =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
    sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
                             offset + shift);
    return {sgprY, sgprZ};
  }

  /// Writes the dimension-0 stride at descriptor bit offset 160.
  std::pair setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Location loc, Value sgpr5, Value sgpr6,
                                ArrayRef consts) const {
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
                               0, 160);
  }

  /// Writes the dimension-1 stride at descriptor bit offset 208 (non-gather
  /// only).
  std::pair setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Location loc, Value sgpr5, Value sgpr6,
                                ArrayRef consts) const {
    // Value is ignored when in gather mode.
    if constexpr (DescriptorOp::isGather())
      return {sgpr5, sgpr6};
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
                               1, 208);
  }

  /// Builds descriptor group 1: eight i32 words, filled in by the field
  /// setters and packed into the converted `vector<8xi32>` type.
  Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef consts) const {
    // Every word starts out as the constant zero.
    Value sgprs[8];
    for (int64_t i = 0; i < 8; ++i) {
      sgprs[i] = consts[0];
    }

    // Word 0: flags and padding configuration.
    sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
    sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);

    // Fields that straddle word boundaries update two words at a time.
    sgprs[1] =
        setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    std::tie(sgprs[2], sgprs[3]) =
        setTensorDim1(op,
adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);

    sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
    sgprs[4] =
        setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
    sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
    std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
        op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
    std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
        op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);

    IntegerType i32 = rewriter.getI32Type();
    Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
    assert(v8i32 && "expected type conversion to succeed");
    Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);

    // zip_equal requires `consts` to hold exactly as many entries as `sgprs`;
    // consts[i] is used as the insertion index for word i.
    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
      dgroup1 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
    }

    return dgroup1;
  }

  /// Single-word overload: writes global size number `dimX` (counted from the
  /// last size) into `sgpr0` at descriptor bit `offset`. Returns `sgpr0`
  /// unchanged when the op has no such dimension.
  Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef consts, int64_t dimX,
                      int64_t offset) const {
    ArrayRef globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= static_cast(dimX))
      return sgpr0;

    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    Value tensorDimX;
    // Static sizes become i32 constants; dynamic sizes are truncated to i32.
    if (auto attr = dyn_cast(tensorDimXOpFoldResult)) {
      tensorDimX =
          createI32Constant(rewriter, loc, cast(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }

    return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
  }

  /// Writes global size 2 at bit offset 0 of `sgpr0`.
  Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr0,
consts, 2, 0);
  }

  /// Truncates `value` to i32 and merges it into `accumulator` at bit
  /// position `shift % 32`.
  Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
                                    Location loc, Value accumulator,
                                    Value value, int64_t shift) const {
    IntegerType i32 = rewriter.getI32Type();
    value = LLVM::TruncOp::create(rewriter, loc, i32, value);
    return setValueAtOffset(rewriter, loc, accumulator, value, shift);
  }

  /// Writes the LDS increment operand into `sgpr1` at descriptor bit `offset`.
  Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            Value sgpr1, ArrayRef consts,
                            int64_t offset) const {
    Value ldsAddrIncrement = adaptor.getLdsIncrement();
    return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
  }

  /// Writes the global increment operand as a 48-bit field: the low 32 bits
  /// go to `sgpr2` at `offset`, the next 16 bits to `sgpr3` at `offset + 32`.
  std::pair setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Location loc, Value sgpr2, Value sgpr3,
                                   ArrayRef consts, int64_t offset) const {
    Value globalAddrIncrement = adaptor.getGlobalIncrement();
    sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
                                        globalAddrIncrement, offset);
    Value shift = createI64Constant(rewriter, loc, 32);
    globalAddrIncrement =
        LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
    constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
    sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
                                        globalAddrIncrement, offset + 32);
    // NOTE(review): this mask is applied to the whole word after merging, so
    // any bits already present in `sgpr3` above bit 15 would be cleared. This
    // looks harmless only while `sgpr3` is still zero on entry; confirm
    // against the caller's ordering.
    Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
    sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
    return {sgpr2, sgpr3};
  }

  /// Word 1 of group 2 holds global size 3, or the LDS increment when the op
  /// has one.
  Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1,
                                        ArrayRef consts) const {
    Value ldsIncrement = op.getLdsIncrement();
    constexpr int64_t dim = 3;
    constexpr int64_t offset = 32;
    if (!ldsIncrement)
      return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
                           offset);
    return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
                               offset);
  }

  /// Words 2 and 3 of group 2 hold the dimension-2 stride, or the global
  /// increment when the op has one.
  std::pair setTensorDim2StrideOrGlobalAddrIncrement(
      DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
      Location loc, Value sgpr2, Value sgpr3, ArrayRef
consts) const {
    Value globalIncrement = op.getGlobalIncrement();
    constexpr int32_t dim = 2;
    constexpr int32_t offset = 64;
    if (!globalIncrement)
      return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
                                 consts, dim, offset);
    return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
                                  consts, offset);
  }

  /// Writes `iterationCount - 1` (truncated to i32) into `sgpr3` at descriptor
  /// bit `offset`.
  Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr3, ArrayRef consts,
                        int32_t offset) const {
    Value iterationCount = adaptor.getIterationCount();
    IntegerType i32 = rewriter.getI32Type();
    // pre-condition: iterationCount is in the inclusive interval [1, 256].
    // TODO: validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime.
    iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
    iterationCount =
        LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
    return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
  }

  /// The field at descriptor bit offset 112 of group 2 holds tile dimension 3,
  /// or the iteration count when the op has one.
  Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc, Value sgpr3,
                                  ArrayRef consts) const {
    Value iterateCount = op.getIterationCount();
    // Tile dimension 3 is shared size index 3, matching the other setters
    // (setTileDim0/1/2 use 0/1/2, setTensorDim3OrLDSAddrIncrement uses 3).
    // Index 2 would write tile dimension 2 a second time, since setTileDim2
    // already stores it in group 1.
    constexpr int32_t dim = 3;
    constexpr int32_t offset = 112;
    if (!iterateCount)
      return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
                         offset);

    return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
  }

  /// Builds descriptor group 2, dispatching at compile time on whether the
  /// descriptor op is a gather.
  Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef consts) const {
    if constexpr (DescriptorOp::isGather())
      return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
    return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
  }

  /// Builds descriptor group 2 for non-gather descriptors as a converted
  /// `vector<4xi32>`.
  Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
ArrayRef consts) const { IntegerType i32 = rewriter.getI32Type(); Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32)); assert(v4i32 && "expected type conversion to succeed."); bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2; if (onlyNeedsTwoDescriptors) return LLVM::ZeroOp::create(rewriter, loc, v4i32); constexpr int64_t sgprlen = 4; Value sgprs[sgprlen]; for (int i = 0; i < sgprlen; ++i) sgprs[i] = consts[0]; sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts); sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc, sgprs[1], consts); std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement( op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts); sgprs[3] = setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts); Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32); for (auto [sgpr, constant] : llvm::zip(sgprs, consts)) dgroup2 = LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant); return dgroup2; } Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Location loc, ArrayRef consts, bool firstHalf) const { IntegerType i32 = rewriter.getI32Type(); Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32)); assert(v4i32 && "expected type conversion to succeed."); Value indices = adaptor.getIndices(); auto vectorType = cast(indices.getType()); unsigned length = vectorType.getShape().back(); Type elementType = vectorType.getElementType(); unsigned maxLength = elementType == i32 ? 4 : 8; int32_t offset = firstHalf ? 
0 : maxLength; unsigned discountedLength = std::max(static_cast(length - offset), 0); unsigned targetSize = std::min(maxLength, discountedLength); SmallVector indicesVector; for (unsigned i = offset; i < targetSize + offset; ++i) { Value idx; if (i < consts.size()) idx = consts[i]; else idx = createI32Constant(rewriter, loc, i); Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx); indicesVector.push_back(elem); } SmallVector indicesI32Vector; if (elementType == i32) { indicesI32Vector = indicesVector; } else { for (unsigned i = 0; i < targetSize; ++i) { Value index = indicesVector[i]; indicesI32Vector.push_back( LLVM::ZExtOp::create(rewriter, loc, i32, index)); } if ((targetSize % 2) != 0) // Add padding when not divisible by two. indicesI32Vector.push_back(consts[0]); } SmallVector indicesToInsert; if (elementType == i32) { indicesToInsert = indicesI32Vector; } else { unsigned size = indicesI32Vector.size() / 2; for (unsigned i = 0; i < size; ++i) { Value first = indicesI32Vector[2 * i]; Value second = indicesI32Vector[2 * i + 1]; Value joined = setValueAtOffset(rewriter, loc, first, second, 16); indicesToInsert.push_back(joined); } } Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32); for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts)) dgroup = LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant); return dgroup; } Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Location loc, ArrayRef consts) const { return getGatherIndices(op, adaptor, rewriter, loc, consts, true); } std::pair setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Location loc, Value sgpr0, Value sgpr1, ArrayRef consts) const { constexpr int32_t dim = 3; constexpr int32_t offset = 0; return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts, dim, offset); } std::pair setTensorDim4(DescriptorOp op, OpAdaptor adaptor, 
ConversionPatternRewriter &rewriter, Location loc, Value sgpr1, Value sgpr2, ArrayRef consts) const { constexpr int32_t dim = 4; constexpr int32_t offset = 48; return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim, offset); } Value setTileDim4(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Location loc, Value sgpr2, ArrayRef consts) const { constexpr int32_t dim = 4; constexpr int32_t offset = 80; return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset); } Value getDGroup3(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Location loc, ArrayRef consts) const { if constexpr (DescriptorOp::isGather()) return getDGroup3Gather(op, adaptor, rewriter, loc, consts); return getDGroup3NonGather(op, adaptor, rewriter, loc, consts); } Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Location loc, ArrayRef consts) const { IntegerType i32 = rewriter.getI32Type(); Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32)); assert(v4i32 && "expected type conversion to succeed."); bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2; if (onlyNeedsTwoDescriptors) return LLVM::ZeroOp::create(rewriter, loc, v4i32); constexpr int32_t sgprlen = 4; Value sgprs[sgprlen]; for (int i = 0; i < sgprlen; ++i) sgprs[i] = consts[0]; std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride( op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts); std::tie(sgprs[1], sgprs[2]) = setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts); sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts); Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32); for (auto [sgpr, constant] : llvm::zip(sgprs, consts)) dgroup3 = LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant); return dgroup3; } Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Location 
loc, ArrayRef consts) const { return getGatherIndices(op, adaptor, rewriter, loc, consts, false); } LogicalResult matchAndRewrite(DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (chipset < kGfx1250) return op->emitOpError( "make_dma_descriptor is only supported on gfx1250"); Location loc = op.getLoc(); SmallVector consts; for (int64_t i = 0; i < 8; ++i) consts.push_back(createI32Constant(rewriter, loc, i)); Value dgroup0 = this->getDGroup0(adaptor); Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts); Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts); Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts); SmallVector results = {dgroup0, dgroup1, dgroup2, dgroup3}; rewriter.replaceOpWithMultiple(op, {results}); return success(); } }; template struct AMDGPUTensorLoadStoreOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Adaptor = typename ConvertOpToLLVMPattern::OneToNOpAdaptor; AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(SourceOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (chipset < kGfx1250) return op->emitOpError("is only supported on gfx1250"); ValueRange desc = adaptor.getDesc(); rewriter.replaceOpWithNewOp(op, desc[0], desc[1], desc[2], desc[3], /*cachePolicy=*/0, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); return success(); } }; struct ConvertAMDGPUToROCDLPass : public impl::ConvertAMDGPUToROCDLPassBase { using Base::Base; void runOnOperation() override { MLIRContext *ctx = &getContext(); FailureOr maybeChipset = Chipset::parse(chipset); if (failed(maybeChipset)) { emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset); return signalPassFailure(); } RewritePatternSet patterns(ctx); 
LLVMTypeConverter converter(ctx); populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset); amdgpu::populateCommonGPUTypeAndAttributeConversions(converter); LLVMConversionTarget target(getContext()); target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>(); target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace void mlir::amdgpu::populateCommonGPUTypeAndAttributeConversions( TypeConverter &typeConverter) { populateGpuMemorySpaceAttributeConversions( typeConverter, [](gpu::AddressSpace space) { switch (space) { case gpu::AddressSpace::Global: return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace; case gpu::AddressSpace::Workgroup: return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace; case gpu::AddressSpace::Private: return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace; } llvm_unreachable("unknown address space enum value"); }); } void mlir::populateAMDGPUTypeAndAttributeConversions( TypeConverter &typeConverter) { typeConverter.addTypeAttributeConversion( [](BaseMemRefType type, amdgpu::AddressSpaceAttr as) -> TypeConverter::AttributeConversionResult { MLIRContext *ctx = as.getContext(); Type i64 = IntegerType::get(ctx, 64); switch (as.getValue()) { case amdgpu::AddressSpace::FatRawBuffer: return IntegerAttr::get(i64, 7); case amdgpu::AddressSpace::BufferRsrc: return IntegerAttr::get(i64, 8); case amdgpu::AddressSpace::FatStructuredBuffer: return IntegerAttr::get(i64, 9); } return TypeConverter::AttributeConversionResult::abort(); }); typeConverter.addConversion([&](DsBarrierStateType type) -> Type { return IntegerType::get(type.getContext(), 64); }); typeConverter.addConversion([&](TDMBaseType type) -> Type { Type i32 = IntegerType::get(type.getContext(), 32); return typeConverter.convertType(VectorType::get(4, i32)); }); typeConverter.addConversion([&](TDMGatherBaseType 
type) -> Type { Type i32 = IntegerType::get(type.getContext(), 32); return typeConverter.convertType(VectorType::get(4, i32)); }); typeConverter.addConversion( [&](TDMDescriptorType type, SmallVectorImpl &result) -> std::optional { Type i32 = IntegerType::get(type.getContext(), 32); Type v4i32 = typeConverter.convertType(VectorType::get(4, i32)); Type v8i32 = typeConverter.convertType(VectorType::get(8, i32)); llvm::append_values(result, v4i32, v8i32, v4i32, v4i32); return success(); }); auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types, ValueRange inputs, Location loc) -> SmallVector { // Only create unrealized_conversion_cast for TDMDescriptorType. // All other types which are not expected, should be // materialized by other target materialization functions. if (inputs.size() != 1) return {}; if (!isa(inputs[0].getType())) return {}; auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs); return cast.getResults(); }; typeConverter.addTargetMaterialization(addUnrealizedCast); } void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, Chipset chipset) { populateAMDGPUTypeAndAttributeConversions(converter); patterns .add, RawBufferOpLowering, RawBufferOpLowering, RawBufferOpLowering, RawBufferOpLowering, RawBufferOpLowering, RawBufferOpLowering, AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaBaseLowering, AMDGPULowerDescriptor, AMDGPULowerDescriptor, AMDGPUTensorLoadStoreOpLowering, AMDGPUTensorLoadStoreOpLowering, DsBarrierInitOpLowering, 
DsBarrierPollStateOpLowering, DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering>(converter, chipset); patterns.add(converter); }