This PR adds a `constant` address space to the GPU dialect and lowerings to all GPU backends. Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
796 lines
32 KiB
C++
796 lines
32 KiB
C++
//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a pass to generate ROCDLIR operations for higher-level
|
|
// GPU operations.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
|
|
#include "mlir/Dialect/Arith/Transforms/Passes.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
|
|
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
|
|
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
|
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
|
|
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
|
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
|
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
|
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
|
|
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/GPU/Transforms/Passes.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
#include "../GPUCommon/GPUOpsLowering.h"
|
|
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
// Truncate or extend the result depending on the index bitwidth specified
|
|
// by the LLVMTypeConverter options.
|
|
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
|
|
Location loc, Value value,
|
|
const LLVMTypeConverter &converter) {
|
|
int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
|
|
int64_t indexBitwidth = converter.getIndexTypeBitwidth();
|
|
auto indexBitwidthType =
|
|
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
|
|
// TODO: use <=> in C++20.
|
|
if (indexBitwidth > intWidth) {
|
|
return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value);
|
|
}
|
|
if (indexBitwidth < intWidth) {
|
|
return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value);
|
|
}
|
|
return value;
|
|
}
|
|
|
|
/// Returns true if the given `gpu.func` can be safely called using the bare
|
|
/// pointer calling convention.
|
|
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
|
|
bool canBeBare = true;
|
|
for (Type type : func.getArgumentTypes())
|
|
if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
|
|
canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
|
|
return canBeBare;
|
|
}
|
|
|
|
/// Materializes the index of the current lane as an i32 via the pair
///   %lo = rocdl.mbcnt.lo(-1, 0)    with result attrs {noundef, range [0, 32)}
///   %id = rocdl.mbcnt.hi(-1, %lo)  with result attrs {noundef, range [0, 64)}
static Value getLaneId(RewriterBase &rewriter, Location loc) {
  MLIRContext *ctx = rewriter.getContext();
  auto i32Type = IntegerType::get(ctx, 32);

  // Builds the single-result attribute list `[{noundef, range [0, upper)}]`.
  auto buildResultAttrs = [&](uint64_t upper) -> ArrayAttr {
    NamedAttribute noundefAttr = rewriter.getNamedAttr(
        LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
    NamedAttribute rangeAttr = rewriter.getNamedAttr(
        LLVM::LLVMDialect::getRangeAttrName(),
        LLVM::ConstantRangeAttr::get(ctx, APInt::getZero(32),
                                     APInt(32, upper)));
    return rewriter.getArrayAttr(
        rewriter.getDictionaryAttr({noundefAttr, rangeAttr}));
  };

  Value zeroI32 = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
  Value allOnes = arith::ConstantIntOp::create(rewriter, loc, -1, 32);

  Value lowCount = ROCDL::MbcntLoOp::create(
      rewriter, loc, i32Type, allOnes, zeroI32, /*arg_attrs=*/{},
      /*res_attrs=*/buildResultAttrs(32));
  Value laneIndex = ROCDL::MbcntHiOp::create(
      rewriter, loc, i32Type, allOnes, lowCount, /*arg_attrs=*/{},
      /*res_attrs=*/buildResultAttrs(64));
  return laneIndex;
}
|
|
|
|
/// Hardware limit on the number of threads along a single block dimension on
/// AMD GPUs. When no tighter bound is available, this value is the inclusive
/// maximum of the `range` attached to `__ockl_get_local_size` calls.
static constexpr int64_t kMaxThreadsPerBlockDim{1024};
|
|
|
|
/// Emits a call to an OCKL block/grid size function corresponding to
|
|
/// `indexKind` with argument `dim`, except that if the context around
|
|
/// `contextOp` gives an exact size for that dimension, return that as
|
|
/// an `i64` constant instead.
|
|
static Value getKnownOrOcklDim(RewriterBase &rewriter,
|
|
gpu::index_lowering::IndexKind indexKind,
|
|
gpu::Dimension dim, Operation *contextOp,
|
|
std::optional<uint32_t> opUpperBound) {
|
|
Location loc = contextOp->getLoc();
|
|
MLIRContext *context = contextOp->getContext();
|
|
|
|
auto i32Ty = IntegerType::get(context, 32);
|
|
auto i64Ty = IntegerType::get(context, 64);
|
|
|
|
if (std::optional<uint32_t> knownDim =
|
|
gpu::getKnownDimensionSizeAround(contextOp, indexKind, dim))
|
|
return LLVM::ConstantOp::create(rewriter, loc,
|
|
rewriter.getI64IntegerAttr(*knownDim));
|
|
|
|
int32_t dimParam = static_cast<int32_t>(dim);
|
|
|
|
StringRef functionName;
|
|
switch (indexKind) {
|
|
case gpu::index_lowering::IndexKind::Block:
|
|
functionName = "__ockl_get_local_size";
|
|
break;
|
|
case gpu::index_lowering::IndexKind::Grid:
|
|
functionName = "__ockl_get_num_groups";
|
|
break;
|
|
case gpu::index_lowering::IndexKind::Cluster:
|
|
case gpu::index_lowering::IndexKind::Other:
|
|
llvm_unreachable("Not valid index kinds for ockl lookup");
|
|
}
|
|
|
|
// Declare the ockl function: i64 @functionName(i32).
|
|
auto fnType = LLVM::LLVMFunctionType::get(i64Ty, {i32Ty});
|
|
Operation *moduleOp = contextOp->getParentWithTrait<OpTrait::SymbolTable>();
|
|
LLVM::LLVMFuncOp funcOp =
|
|
getOrDefineFunction(moduleOp, loc, rewriter, functionName, fnType);
|
|
|
|
// Create the call.
|
|
Value dimConst = LLVM::ConstantOp::create(rewriter, loc, i32Ty, dimParam);
|
|
auto callOp =
|
|
LLVM::CallOp::create(rewriter, loc, funcOp, ValueRange{dimConst});
|
|
|
|
LLVM::ConstantRangeAttr range;
|
|
if (opUpperBound) {
|
|
range = LLVM::ConstantRangeAttr::get(
|
|
context, APInt(64, 1),
|
|
APInt(64, static_cast<uint64_t>(*opUpperBound) + 1));
|
|
} else if (indexKind == gpu::index_lowering::IndexKind::Block) {
|
|
// Set the hardware limit for block ranges as the bounds on block dim calls.
|
|
range = LLVM::ConstantRangeAttr::get(context, APInt(64, 1),
|
|
APInt(64, kMaxThreadsPerBlockDim + 1));
|
|
}
|
|
if (range) {
|
|
callOp.setResAttrsAttr(rewriter.getArrayAttr(rewriter.getDictionaryAttr(
|
|
rewriter.getNamedAttr(LLVM::LLVMDialect::getRangeAttrName(), range))));
|
|
}
|
|
return callOp.getResult();
|
|
}
|
|
|
|
// Default LLVM data layout for the amdgcn target. The pass attaches it to a
// gpu.module that does not already carry an LLVM data layout attribute.
// The literal is split at component boundaries; the concatenated string is
// what matters.
// NOTE(review): presumably must track the AMDGPU backend's data layout —
// re-check when updating LLVM.
static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32"
    "-i64:64"
    "-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512"
    "-v1024:1024-v2048:2048"
    "-n32:64-S32-A5-G1-ni:7:8:9";
|
|
|
|
namespace {
|
|
|
|
/// Lowers gpu.block_dim / gpu.grid_dim to direct __ockl_get_local_size /
|
|
/// __ockl_get_num_groups function calls.
|
|
template <typename OpTy>
|
|
struct GPUDimOpToOcklCall final : ConvertOpToLLVMPattern<OpTy> {
|
|
GPUDimOpToOcklCall(const LLVMTypeConverter &converter,
|
|
gpu::index_lowering::IndexKind indexKind)
|
|
: ConvertOpToLLVMPattern<OpTy>(converter), indexKind(indexKind) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
|
|
std::optional<uint32_t> opUpperBound;
|
|
if (auto bound = op.getUpperBound())
|
|
opUpperBound = static_cast<uint32_t>(bound->getZExtValue());
|
|
|
|
Value ocklCall = getKnownOrOcklDim(rewriter, indexKind, op.getDimension(),
|
|
op, opUpperBound);
|
|
Value result = truncOrExtToLLVMType(rewriter, loc, ocklCall,
|
|
*this->getTypeConverter());
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
const gpu::index_lowering::IndexKind indexKind;
|
|
};
|
|
|
|
/// Lowers `gpu.lane_id` to:
///   %mlo = call noundef range(i32 0, 32) @llvm.amdgcn.mbcnt.lo(-1, 0)
///   %lid = call noundef range(i32 0, 64) @llvm.amdgcn.mbcnt.hi(-1, %mlo)
/// followed by a sign-extension or truncation of the i32 lane id to the index
/// bitwidth configured on the LLVMTypeConverter.
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value laneId = getLaneId(rewriter, loc);
    // Use the shared helper rather than an inline copy of the sext/trunc
    // logic, so this pattern stays consistent with the other index lowerings.
    laneId = truncOrExtToLLVMType(rewriter, loc, laneId, *getTypeConverter());
    rewriter.replaceOp(op, {laneId});
    return success();
  }
};
|
|
|
|
/// Lowers `gpu.subgroup_size` to `rocdl.wavefrontsize`.
///
/// When the op carries an `upper_bound`, the result gets a range:
///   - lower end: 64 on chips before gfx10, 32 from gfx10 on;
///   - upper end: the bound, treated as inclusive.
/// The i32 result is then adapted to the converter's index bitwidth.
struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  GPUSubgroupSizeOpToROCDL(const LLVMTypeConverter &converter,
                           amdgpu::Chipset chipset)
      : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp>(converter),
        chipset(chipset) {}

  LogicalResult
  matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    LLVM::ConstantRangeAttr bounds = nullptr;
    if (auto upperBound = op.getUpperBoundAttr()) {
      const int minWaveSize = chipset.majorVersion < 10 ? 64 : 32;
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/minWaveSize,
          /*upper=*/upperBound.getInt() + 1);
    }

    Value waveSize = ROCDL::WavefrontSizeOp::create(
        rewriter, loc, rewriter.getI32Type(), bounds);
    Value indexValue =
        truncOrExtToLLVMType(rewriter, loc, waveSize, *getTypeConverter());
    rewriter.replaceOp(op, {indexValue});
    return success();
  }

  const amdgpu::Chipset chipset;
};
|
|
|
|
/// Lowers `gpu.subgroup_id`.
///
/// - gfx12 and newer: the id is read with `rocdl.wave_id`. An optional
///   `upper_bound` on the op becomes a `[0, upper_bound)` range on the result.
/// - Older chips: the id is derived from the linearized thread id:
///     (tid.x + dim.x * (tid.y + dim.y * tid.z)) udiv wavefront_size
///   Thread-id ops get `range` attributes from the surrounding context when
///   one is available, and the index arithmetic carries `nsw`/`nuw`.
///
/// The i32 result is finally adapted to the converter's index bitwidth.
struct GPUSubgroupIdOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  GPUSubgroupIdOpToROCDL(const LLVMTypeConverter &converter,
                         amdgpu::Chipset chipset)
      : ConvertOpToLLVMPattern<gpu::SubgroupIdOp>(converter),
        chipset(chipset) {}

  LogicalResult
  matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const bool hasWaveIdOp = chipset.majorVersion >= 12;
    Value waveIndex = hasWaveIdOp ? emitHardwareWaveId(op, rewriter)
                                  : emitLinearizedWaveId(op, rewriter);
    waveIndex = truncOrExtToLLVMType(rewriter, op.getLoc(), waveIndex,
                                     *getTypeConverter());
    rewriter.replaceOp(op, waveIndex);
    return success();
  }

  const amdgpu::Chipset chipset;

private:
  /// gfx12+: reads the wave id directly through `rocdl.wave_id`.
  static Value emitHardwareWaveId(gpu::SubgroupIdOp op,
                                  ConversionPatternRewriter &rewriter) {
    LLVM::ConstantRangeAttr bounds;
    if (auto upperBound = op.getUpperBoundAttr())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0,
          /*upper=*/upperBound.getInt());
    return ROCDL::WaveId::create(rewriter, op.getLoc(), rewriter.getI32Type(),
                                 bounds);
  }

  /// Pre-gfx12: computes linearized_thread_id / wavefront_size.
  static Value emitLinearizedWaveId(gpu::SubgroupIdOp op,
                                    ConversionPatternRewriter &rewriter) {
    using gpu::index_lowering::IndexKind;
    using gpu::index_lowering::IntrType;

    Location loc = op.getLoc();
    auto i32 = rewriter.getI32Type();
    const auto overflowFlags =
        LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;

    auto tidX = ROCDL::ThreadIdXOp::create(rewriter, loc, i32);
    auto tidY = ROCDL::ThreadIdYOp::create(rewriter, loc, i32);
    auto tidZ = ROCDL::ThreadIdZOp::create(rewriter, loc, i32);

    // Attach a `range` to a thread-id op when the context bounds it.
    auto attachContextRange = [&](Operation *tidOp, gpu::Dimension dim) {
      LLVM::ConstantRangeAttr range = gpu::index_lowering::getIndexOpRange(
          op, dim, std::nullopt, IndexKind::Block, IntrType::Id, 32);
      if (range)
        tidOp->setAttr("range", range);
    };
    attachContextRange(tidX, gpu::Dimension::x);
    attachContextRange(tidY, gpu::Dimension::y);
    attachContextRange(tidZ, gpu::Dimension::z);

    // Block sizes are produced as i64; narrow them to i32.
    auto blockDimAsI32 = [&](gpu::Dimension dim) -> Value {
      Value wide = getKnownOrOcklDim(rewriter, IndexKind::Block, dim, op,
                                     std::nullopt);
      return LLVM::TruncOp::create(rewriter, loc, i32, wide, overflowFlags);
    };
    Value dimX = blockDimAsI32(gpu::Dimension::x);
    Value dimY = blockDimAsI32(gpu::Dimension::y);

    // tid.x + dim.x * (tid.y + dim.y * tid.z). Thread ids and block sizes
    // are small non-negative values, so nsw + nuw hold.
    Value zPlane =
        LLVM::MulOp::create(rewriter, loc, i32, dimY, tidZ, overflowFlags);
    Value yzRow =
        LLVM::AddOp::create(rewriter, loc, i32, tidY, zPlane, overflowFlags);
    Value yzOffset =
        LLVM::MulOp::create(rewriter, loc, i32, dimX, yzRow, overflowFlags);
    Value linearTid =
        LLVM::AddOp::create(rewriter, loc, i32, tidX, yzOffset, overflowFlags);

    Value waveSize = ROCDL::WavefrontSizeOp::create(rewriter, loc, i32);
    return LLVM::UDivOp::create(rewriter, loc, i32, linearTid, waveSize);
  }
};
|
|
|
|
/// Returns true if `type` can be fed to `rocdl.readlane` /
/// `rocdl.readfirstlane` as-is, without being split into i32 pieces.
/// See https://llvm.org/docs/AMDGPUUsage.html#llvm-ir-intrinsics. Accepted:
///   - f16, bf16, f32, f64 and LLVM pointers;
///   - integers of width 16, 32 or 64;
///   - vectors of 32-bit integers (any length);
///   - 2-element vectors of f16, bf16 or 16-bit integers.
static bool isSupportedReadLaneType(Type type) {
  if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type,
          LLVM::LLVMPointerType>(type))
    return true;

  if (auto intType = dyn_cast<IntegerType>(type)) {
    switch (intType.getWidth()) {
    case 16:
    case 32:
    case 64:
      return true;
    default:
      return false;
    }
  }

  auto vecType = dyn_cast<VectorType>(type);
  if (!vecType)
    return false;

  Type elemType = vecType.getElementType();
  if (elemType.isInteger(32))
    return true;
  return vecType.getNumElements() == 2 &&
         (isa<Float16Type, BFloat16Type>(elemType) || elemType.isInteger(16));
}
|
|
|
|
struct GPUSubgroupBroadcastOpToROCDL
|
|
: public ConvertOpToLLVMPattern<gpu::SubgroupBroadcastOp> {
|
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Value src = adaptor.getSrc();
|
|
if (isSupportedReadLaneType(src.getType())) {
|
|
Value result = createReadlaneOp(op, adaptor, rewriter, src);
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
|
|
Type i32 = rewriter.getI32Type();
|
|
Location loc = op.getLoc();
|
|
SmallVector<Value> decomposed;
|
|
if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed,
|
|
/*permitVariablySizedScalars=*/true)))
|
|
return rewriter.notifyMatchFailure(op,
|
|
"Unexpected decomposition failure");
|
|
|
|
SmallVector<Value> results;
|
|
results.reserve(decomposed.size());
|
|
for (Value v : decomposed)
|
|
results.emplace_back(createReadlaneOp(op, adaptor, rewriter, v));
|
|
|
|
Value result = LLVM::composeValue(rewriter, loc, results, src.getType());
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
static Value createReadlaneOp(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter,
|
|
Value src) {
|
|
if (adaptor.getBroadcastType() == gpu::BroadcastType::specific_lane) {
|
|
return ROCDL::ReadlaneOp::create(rewriter, op.getLoc(), src.getType(),
|
|
src, adaptor.getLane());
|
|
} else { // first_active_lane
|
|
return ROCDL::ReadfirstlaneOp::create(rewriter, op.getLoc(),
|
|
src.getType(), src);
|
|
}
|
|
}
|
|
};
|
|
|
|
/// Lowers `gpu.ballot` to `rocdl.ballot`. Only i32 and i64 results are
/// accepted, one mask bit per lane for 32- and 64-lane wavefronts; any other
/// width is reported as a match failure.
struct GPUBallotOpToROCDL : public ConvertOpToLLVMPattern<gpu::BallotOp> {
  using ConvertOpToLLVMPattern<gpu::BallotOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::BallotOp op, gpu::BallotOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = op.getType();
    const unsigned maskWidth = cast<IntegerType>(resultType).getWidth();
    const bool isNativeMaskWidth = maskWidth == 32 || maskWidth == 64;
    if (!isNativeMaskWidth)
      return rewriter.notifyMatchFailure(
          op, "rocdl.ballot only supports i32 and i64 result types");

    rewriter.replaceOpWithNewOp<ROCDL::BallotOp>(op, resultType,
                                                 adaptor.getPredicate());
    return success();
  }
};
|
|
|
|
/// Lowers `gpu.shuffle` to `rocdl.ds_bpermute`.
///
/// Algorithm (`width` is usually the wavefront size):
///   1. laneId     = mbcnt.hi(-1, mbcnt.lo(-1, 0))
///   2. segmentEnd = (laneId + width) & -width
///                   For power-of-two `width`, this is one past the last lane
///                   of the width-sized segment containing `laneId`.
///   3. dstLane    = laneId ^ offset   (xor)
///                   laneId - offset   (up)
///                   laneId + offset   (down)
///                   offset            (idx)
///   4. isActive   = dstLane < segmentEnd   (signed compare)
///   5. srcLane    = isActive ? dstLane : laneId
///   6. address    = srcLane << 2           (dword-aligned lane address)
///   7. result     = ds_bpermute(address, v) for every i32 piece v of the
///                   value, recomposed into the value's original type.
/// The op's second result is `isActive`. A lane whose `dstLane` is not below
/// `segmentEnd` therefore reads its own value.
///
/// NOTE(review): for `up`, `dstLane` may be negative or below the start of the
/// lane's segment. The single signed `dstLane < segmentEnd` check does not
/// reject those lanes — confirm this matches the intended `gpu.shuffle up`
/// semantics.
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto i32Type = IntegerType::get(rewriter.getContext(), 32);
    Value value = adaptor.getValue();
    Value offset = adaptor.getOffset();
    Value width = adaptor.getWidth();

    // Step 1.
    Value laneId = getLaneId(rewriter, loc);

    // Step 2: segmentEnd = (laneId + width) & (0 - width).
    Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
    Value minusWidth = LLVM::SubOp::create(rewriter, loc, i32Type, zero, width);
    Value laneIdPlusWidth =
        LLVM::AddOp::create(rewriter, loc, i32Type, laneId, width);
    Value segmentEnd = LLVM::AndOp::create(rewriter, loc, i32Type,
                                           laneIdPlusWidth, minusWidth);

    // Step 3.
    Value dstLane;
    switch (op.getMode()) {
    case gpu::ShuffleMode::UP:
      dstLane = LLVM::SubOp::create(rewriter, loc, i32Type, laneId, offset);
      break;
    case gpu::ShuffleMode::DOWN:
      dstLane = LLVM::AddOp::create(rewriter, loc, i32Type, laneId, offset);
      break;
    case gpu::ShuffleMode::XOR:
      dstLane = LLVM::XOrOp::create(rewriter, loc, i32Type, laneId, offset);
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = offset;
      break;
    }

    // Steps 4-6.
    Value isActive = LLVM::ICmpOp::create(
        rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, segmentEnd);
    Value srcLane =
        LLVM::SelectOp::create(rewriter, loc, isActive, dstLane, laneId);
    Value two = LLVM::ConstantOp::create(rewriter, loc, i32Type, 2);
    Value laneAddress =
        LLVM::ShlOp::create(rewriter, loc, i32Type, srcLane, two);

    // Step 7: permute each i32 piece and reassemble.
    SmallVector<Value> pieces;
    if (failed(LLVM::decomposeValue(rewriter, loc, value, i32Type, pieces)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to decompose value to i32");

    SmallVector<Value> permuted;
    permuted.reserve(pieces.size());
    for (Value piece : pieces) {
      Value moved = ROCDL::DsBpermuteOp::create(rewriter, loc, i32Type,
                                                laneAddress, piece);
      permuted.push_back(moved);
    }

    Value shuffled =
        LLVM::composeValue(rewriter, loc, permuted, value.getType());
    rewriter.replaceOp(op, {shuffled, isActive});
    return success();
  }
};
|
|
|
|
/// Lowers `gpu.barrier` to:
///   workgroup-scope release fence, hardware barrier, workgroup-scope acquire
///   fence.
///
/// Which memory is fenced depends on the optional `address_spaces` attribute.
/// Global and workgroup (LDS) spaces are fenced when listed; private and
/// constant spaces never need fencing. Without the attribute, both global and
/// LDS memory are fenced, matching `__syncthreads()`.
///
/// When exactly one of global/LDS is fenced, both fences carry an
/// `amdgpu-synchronize-as` MMRA tag naming that space. When neither is fenced,
/// no fences are emitted.
///
/// The barrier itself is `ROCDL::SBarrierOp` before gfx12, and a
/// `BarrierSignalOp` / `BarrierWaitOp` pair with id -1 on gfx12+.
struct GPUBarrierOpLowering final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
  GPUBarrierOpLowering(const LLVMTypeConverter &converter,
                       amdgpu::Chipset chipset)
      : ConvertOpToLLVMPattern<gpu::BarrierOp>(converter), chipset(chipset) {}

  amdgpu::Chipset chipset;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Default semantics (no attribute): fence both global and LDS memory.
    bool fenceGlobal = true;
    bool fenceLDS = true;
    if (std::optional<ArrayAttr> requested = op.getAddressSpaces()) {
      fenceGlobal = false;
      fenceLDS = false;
      for (auto space : requested->getAsRange<gpu::AddressSpaceAttr>()) {
        switch (space.getValue()) {
        case gpu::AddressSpace::Global:
          fenceGlobal = true;
          break;
        case gpu::AddressSpace::Workgroup:
          fenceLDS = true;
          break;
        case gpu::AddressSpace::Private:
        case gpu::AddressSpace::Constant:
          // Private memory is per-thread and constant memory is read-only;
          // neither needs ordering across the barrier.
          break;
        }
      }
    }

    // Narrow the fences to a single address space when only one is needed.
    Attribute mmra;
    if (fenceGlobal != fenceLDS)
      mmra = rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as",
                                                 fenceLDS ? "local" : "global");

    const bool needFences = fenceGlobal || fenceLDS;
    constexpr llvm::StringLiteral syncScope = "workgroup";
    auto emitFence = [&](LLVM::AtomicOrdering ordering) {
      if (!needFences)
        return;
      auto fence = LLVM::FenceOp::create(rewriter, loc, ordering, syncScope);
      if (mmra)
        fence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    };

    emitFence(LLVM::AtomicOrdering::release);
    if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }
    emitFence(LLVM::AtomicOrdering::acquire);

    rewriter.eraseOp(op);
    return success();
  }
};
|
|
|
|
/// Import the GPU Ops to ROCDL Patterns.
|
|
#include "GPUToROCDL.cpp.inc"
|
|
|
|
// A pass that replaces all occurrences of GPU device operations with their
|
|
// corresponding ROCDL equivalent.
|
|
//
|
|
// This pass only handles device code and is not meant to be run on GPU host
|
|
// code.
|
|
struct LowerGpuOpsToROCDLOpsPass final
|
|
: public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
|
|
using Base::Base;
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
Base::getDependentDialects(registry);
|
|
registerConvertToLLVMDependentDialectLoading(registry);
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
gpu::GPUModuleOp m = getOperation();
|
|
MLIRContext *ctx = m.getContext();
|
|
|
|
auto llvmDataLayout = m->getAttrOfType<StringAttr>(
|
|
LLVM::LLVMDialect::getDataLayoutAttrName());
|
|
if (!llvmDataLayout) {
|
|
llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
|
|
m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
|
|
}
|
|
// Request C wrapper emission.
|
|
for (auto func : m.getOps<func::FuncOp>()) {
|
|
func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
|
|
UnitAttr::get(ctx));
|
|
}
|
|
|
|
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
|
|
if (failed(maybeChipset)) {
|
|
emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
|
|
return signalPassFailure();
|
|
}
|
|
|
|
/// Customize the bitwidth used for the device side index computations.
|
|
LowerToLLVMOptions options(
|
|
ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
|
|
options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
|
|
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
|
|
options.overrideIndexBitwidth(indexBitwidth);
|
|
|
|
if (useBarePtrCallConv) {
|
|
options.useBarePtrCallConv = true;
|
|
WalkResult canUseBarePointers =
|
|
m.walk([](gpu::GPUFuncOp func) -> WalkResult {
|
|
if (canBeCalledWithBarePointers(func))
|
|
return WalkResult::advance();
|
|
return WalkResult::interrupt();
|
|
});
|
|
if (canUseBarePointers.wasInterrupted()) {
|
|
emitError(UnknownLoc::get(ctx),
|
|
"bare pointer calling convention requires all memrefs to "
|
|
"have static shape and use the identity map");
|
|
return signalPassFailure();
|
|
}
|
|
}
|
|
|
|
// Apply in-dialect lowering. In-dialect lowering will replace
|
|
// ops which need to be lowered further, which is not supported by a
|
|
// single conversion pass.
|
|
{
|
|
RewritePatternSet patterns(ctx);
|
|
populateGpuRewritePatterns(patterns);
|
|
populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
|
|
(void)applyPatternsGreedily(m, std::move(patterns));
|
|
}
|
|
|
|
LLVMTypeConverter converter(ctx, options);
|
|
amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
|
|
|
|
RewritePatternSet llvmPatterns(ctx);
|
|
LLVMConversionTarget target(getContext());
|
|
|
|
llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
|
|
allowedDialects.end());
|
|
for (Dialect *dialect : ctx->getLoadedDialects()) {
|
|
bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
|
|
// Empty `allowedDialectsSet` means all dialects are allowed.
|
|
if (!allowedDialectsSet.empty() && !allowed)
|
|
continue;
|
|
|
|
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
|
|
if (!iface) {
|
|
// Error out if dialect was explicily specified but doesn't implement
|
|
// conversion interface.
|
|
if (allowed) {
|
|
m.emitError()
|
|
<< "dialect does not implement ConvertToLLVMPatternInterface: "
|
|
<< dialect->getNamespace();
|
|
return signalPassFailure();
|
|
}
|
|
continue;
|
|
}
|
|
|
|
iface->populateConvertToLLVMConversionPatterns(target, converter,
|
|
llvmPatterns);
|
|
}
|
|
|
|
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
|
|
*maybeChipset);
|
|
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
|
|
*maybeChipset);
|
|
configureGpuToROCDLConversionLegality(target);
|
|
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
|
|
signalPassFailure();
|
|
auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
|
|
auto reqdWorkGroupSizeAttrHelper =
|
|
rocdlDialect->getReqdWorkGroupSizeAttrHelper();
|
|
auto flatWorkGroupSizeAttrHelper =
|
|
rocdlDialect->getFlatWorkGroupSizeAttrHelper();
|
|
// Manually rewrite known block size attributes so the LLVMIR translation
|
|
// infrastructure can pick them up.
|
|
m.walk([&](LLVM::LLVMFuncOp op) {
|
|
if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
|
|
auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
|
|
// Also set up the rocdl.flat_work_group_size attribute to prevent
|
|
// conflicting metadata.
|
|
uint32_t flatSize = 1;
|
|
for (uint32_t size : blockSizes.asArrayRef()) {
|
|
flatSize *= size;
|
|
}
|
|
StringAttr flatSizeAttr =
|
|
StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
|
|
flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
|
|
}
|
|
});
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Sets up `target` for GPU-to-ROCDL conversion:
///   - LLVM and ROCDL ops are legal;
///   - `func.func` and the GPU dialect must be converted, except `gpu.yield`
///     and `gpu.module`;
///   - the listed LLVM math intrinsics are illegal so they get rewritten,
///     except that `llvm.exp` / `llvm.log` stay legal when any operand is f32.
void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();

  // Math intrinsics that must be rewritten. This has to precede the dynamic
  // rule below, which overrides the action for exp/log.
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
                      LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
                      LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
  target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
  });

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
|
|
|
|
void mlir::populateGpuToROCDLConversionPatterns(
|
|
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
|
mlir::gpu::amd::Runtime runtime, amdgpu::Chipset chipset) {
|
|
using gpu::index_lowering::IndexKind;
|
|
using gpu::index_lowering::IntrType;
|
|
using mlir::gpu::amd::Runtime;
|
|
auto *rocdlDialect =
|
|
converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
|
|
populateWithGenerated(patterns);
|
|
patterns.add<
|
|
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
|
|
ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
|
|
converter, IndexKind::Block, IntrType::Id);
|
|
patterns.add<gpu::index_lowering::OpLowering<
|
|
gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
|
|
converter, IndexKind::Grid, IntrType::Id);
|
|
patterns.add<GPUDimOpToOcklCall<gpu::BlockDimOp>>(converter,
|
|
IndexKind::Block);
|
|
patterns.add<GPUDimOpToOcklCall<gpu::GridDimOp>>(converter, IndexKind::Grid);
|
|
patterns.add<GPUReturnOpLowering>(converter);
|
|
patterns.add<GPUFuncOpLowering>(
|
|
converter,
|
|
GPUFuncOpLoweringOptions{
|
|
/*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
|
|
/*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
|
|
rocdlDialect->getKernelAttrHelper().getName(),
|
|
rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName(),
|
|
/*kernelClusterSizeAttributeName=*/{}});
|
|
if (Runtime::HIP == runtime) {
|
|
patterns.add<GPUPrintfOpToHIPLowering>(converter);
|
|
} else if (Runtime::OpenCL == runtime) {
|
|
// Use address space = 4 to match the OpenCL definition of printf()
|
|
patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
|
|
}
|
|
// TODO: Add alignment for workgroup memory
|
|
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
|
|
|
|
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
|
|
GPUSubgroupBroadcastOpToROCDL, GPUBallotOpToROCDL>(converter);
|
|
patterns.add<GPUSubgroupIdOpToROCDL, GPUSubgroupSizeOpToROCDL,
|
|
GPUBarrierOpLowering>(converter, chipset);
|
|
|
|
populateMathToROCDLConversionPatterns(converter, patterns, chipset);
|
|
}
|