[mlir][GPU] Add NVVM-specific cf.assert lowering (#120431)

This commit adds an NVIDIA-specific lowering of `cf.assert` to
`__assertfail`.

Note: `getUniqueFormatGlobalName`, `getOrCreateFormatStringConstant` and
`getOrDefineFunction` are moved to `GPUOpsLowering.h`, so that they can
be reused.
This commit is contained in:
Matthias Springer 2025-01-06 12:00:11 +01:00 committed by GitHub
parent a13bcf3ced
commit 599c739905
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 258 additions and 64 deletions

View File

@ -3928,6 +3928,7 @@ public:
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
pattern);
mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, pattern);
// Math operations that have not been converted yet must be converted
// to Libm.
if (!isAMDGCN)

View File

@ -220,6 +220,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the

View File

@ -29,6 +29,10 @@ namespace cf {
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
/// references have to remain alive during the entire pattern lifetime.
///
/// Note: This function does not populate the default cf.assert lowering. That
/// is because some platforms have a custom cf.assert lowering. The default
/// lowering can be populated with `populateAssertToLLVMConversionPattern`.
void populateControlFlowToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);

View File

@ -215,7 +215,6 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
AssertOpLowering,
BranchOpLowering,
CondBranchOpLowering,
SwitchOpLowering>(converter);
@ -258,6 +257,7 @@ struct ConvertControlFlowToLLVM
LLVMTypeConverter converter(ctx, options);
RewritePatternSet patterns(ctx);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@ -286,6 +286,7 @@ struct ControlFlowToLLVMDialectInterface
RewritePatternSet &patterns) const final {
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
}
};
} // namespace

View File

@ -19,6 +19,59 @@
using namespace mlir;
/// Returns the `LLVM::LLVMFuncOp` with the symbol name `name` in `moduleOp`.
/// If no such function exists, one with the given `type` and external linkage
/// is created at the start of the module body. The insertion point of `b` is
/// restored before returning. An existing function is returned as is; its type
/// is not compared against `type`.
LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
                                           Location loc, OpBuilder &b,
                                           StringRef name,
                                           LLVM::LLVMFunctionType type) {
  // Reuse a function that is already present in the module.
  if (auto existingFunc = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))
    return existingFunc;

  // Otherwise, create the function at the beginning of the module.
  OpBuilder::InsertionGuard restoreInsertionPoint(b);
  b.setInsertionPointToStart(moduleOp.getBody());
  return b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
}
/// Returns the first name of the form `prefix` + N (N = 0, 1, 2, ...) for
/// which `moduleOp` does not contain a symbol.
static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
                                           StringRef prefix) {
  SmallString<16> candidate;
  for (unsigned suffix = 0;; ++suffix) {
    // Rebuild the candidate name from scratch for every suffix.
    candidate.clear();
    (prefix + Twine(suffix)).toStringRef(candidate);
    if (!moduleOp.lookupSymbol(candidate))
      return candidate;
  }
}
/// Returns a constant global in `moduleOp` that holds `str` followed by a
/// terminating null character. An existing constant global with the same type,
/// value, alignment and address space is reused. Otherwise, a new global with
/// internal linkage and a unique name starting with `namePrefix` is created at
/// the start of the module body.
LLVM::GlobalOp
mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
                                gpu::GPUModuleOp moduleOp, Type llvmI8,
                                StringRef namePrefix, StringRef str,
                                uint64_t alignment, unsigned addrSpace) {
  // Append the null terminator that C expects.
  llvm::SmallString<20> cString(str);
  cString.push_back('\0');
  StringAttr valueAttr = b.getStringAttr(cString);
  auto arrayType = LLVM::LLVMArrayType::get(llvmI8, cString.size_in_bytes());

  // A global can be reused only if it is a constant that agrees in type,
  // value, alignment (a missing alignment counts as 0) and address space.
  auto isEquivalent = [&](LLVM::GlobalOp candidate) {
    return candidate.getGlobalType() == arrayType && candidate.getConstant() &&
           candidate.getValueAttr() == valueAttr &&
           candidate.getAlignment().value_or(0) == alignment &&
           candidate.getAddrSpace() == addrSpace;
  };
  for (LLVM::GlobalOp candidate : moduleOp.getOps<LLVM::GlobalOp>())
    if (isEquivalent(candidate))
      return candidate;

  // No matching global: create one at the beginning of the module.
  OpBuilder::InsertionGuard restoreInsertionPoint(b);
  b.setInsertionPointToStart(moduleOp.getBody());
  SmallString<16> symbolName = getUniqueSymbolName(moduleOp, namePrefix);
  return b.create<LLVM::GlobalOp>(loc, arrayType,
                                  /*isConstant=*/true, LLVM::Linkage::Internal,
                                  symbolName, valueAttr, alignment, addrSpace);
}
LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
return success();
}
/// Returns the first name of the form "printfFormat_" + N (N = 0, 1, 2, ...)
/// for which `moduleOp` does not contain a symbol.
static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
  const char prefix[] = "printfFormat_";
  SmallString<16> candidate;
  for (unsigned suffix = 0;; ++suffix) {
    // Rebuild the candidate name from scratch for every suffix.
    candidate.clear();
    (prefix + Twine(suffix)).toStringRef(candidate);
    if (!moduleOp.lookupSymbol(candidate))
      return candidate;
  }
}
/// Returns a constant global in `moduleOp` that holds the format string `str`
/// followed by a terminating null character. An existing constant global with
/// the same type, value, alignment and address space is reused. Otherwise, a
/// new global with internal linkage and a unique name is created at the start
/// of the module body.
static LLVM::GlobalOp getOrCreateFormatStringConstant(
    OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
    StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
  // Append the null terminator that C expects.
  llvm::SmallString<20> cString(str);
  cString.push_back('\0');
  StringAttr valueAttr = b.getStringAttr(cString);
  auto arrayType = LLVM::LLVMArrayType::get(llvmI8, cString.size_in_bytes());

  // A global can be reused only if it is a constant that agrees in type,
  // value, alignment (a missing alignment counts as 0) and address space.
  auto isEquivalent = [&](LLVM::GlobalOp candidate) {
    return candidate.getGlobalType() == arrayType && candidate.getConstant() &&
           candidate.getValueAttr() == valueAttr &&
           candidate.getAlignment().value_or(0) == alignment &&
           candidate.getAddrSpace() == addrSpace;
  };
  for (LLVM::GlobalOp candidate : moduleOp.getOps<LLVM::GlobalOp>())
    if (isEquivalent(candidate))
      return candidate;

  // No matching global: create one at the beginning of the module.
  OpBuilder::InsertionGuard restoreInsertionPoint(b);
  b.setInsertionPointToStart(moduleOp.getBody());
  SmallString<16> symbolName = getUniqueFormatGlobalName(moduleOp);
  return b.create<LLVM::GlobalOp>(loc, arrayType,
                                  /*isConstant=*/true, LLVM::Linkage::Internal,
                                  symbolName, valueAttr, alignment, addrSpace);
}
/// Returns the `LLVM::LLVMFuncOp` with the symbol name `name` in `moduleOp`.
/// If no such function exists, one with the given `type` and external linkage
/// is created at the start of the module body. The insertion point of
/// `rewriter` is restored before returning.
template <typename T>
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
                                            ConversionPatternRewriter &rewriter,
                                            StringRef name,
                                            LLVM::LLVMFunctionType type) {
  // Reuse a function that is already present in the module.
  if (auto existingFunc =
          moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))
    return existingFunc;

  // Otherwise, create the function at the beginning of the module.
  ConversionPatternRewriter::InsertionGuard restoreInsertionPoint(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());
  return rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
                                           LLVM::Linkage::External);
}
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
Value printfDesc = printfBeginCall.getResult();
// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
LLVM::GlobalOp global = getOrCreateStringConstant(
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element and pass it to printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
addressSpace);
LLVM::GlobalOp global = getOrCreateStringConstant(
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
/*alignment=*/0, addressSpace);
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
LLVM::GlobalOp global = getOrCreateStringConstant(
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);

View File

@ -14,6 +14,27 @@
namespace mlir {
//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//
/// Find or create an external function declaration in the given module.
///
/// If `moduleOp` already contains an `LLVM::LLVMFuncOp` named `name`, that
/// function is returned as is; `type` is not compared against it. Otherwise, a
/// new function with the given `type` and external linkage is created at the
/// start of the module body, using `loc` as its location. The insertion point
/// of `b` is the same before and after the call.
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
OpBuilder &b, StringRef name,
LLVM::LLVMFunctionType type);
/// Create a global that contains the given string. If a global with the same
/// string already exists in the module, return that global.
///
/// A null character is appended to `str` before it is stored. The global is
/// an array of `llvmI8` elements. An existing global is only reused if it is
/// a constant with the same type, value, alignment (a global without an
/// alignment is treated as having alignment 0) and address space. A newly
/// created global is a constant with internal linkage that is inserted at the
/// start of the module body. Its name is `namePrefix` followed by the smallest
/// number, starting from 0, that does not clash with an existing symbol in
/// `moduleOp`.
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
gpu::GPUModuleOp moduleOp, Type llvmI8,
StringRef namePrefix, StringRef str,
uint64_t alignment = 0,
unsigned addrSpace = 0);
//===----------------------------------------------------------------------===//
// Lowering Patterns
//===----------------------------------------------------------------------===//
/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
/// creates a 0-sized global array symbol, similar to what LLVM expects. It
/// then constructs a memref descriptor with these values and returns it.

View File

@ -25,6 +25,7 @@
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@ -236,6 +237,103 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
}
};
/// Lowering of cf.assert into a conditional call to `__assertfail`.
///
/// The block that contains the cf.assert is split so that the call is only
/// reached if the asserted condition is false. The assertion message and the
/// file name, line number and function name derived from the location of the
/// cf.assert are materialized as string constants in the enclosing gpu.module
/// and passed to `__assertfail`. The pattern does not apply to a cf.assert
/// that is not nested in a gpu.module.
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // The function declaration and the string constants are placed in the
    // enclosing gpu.module. Bail out before modifying any IR if there is none.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    if (!moduleOp)
      return rewriter.notifyMatchFailure(
          assertOp, "expected cf.assert to be nested in a gpu.module");

    // Find or create __assertfail function declaration.
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert conditional branch.
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
                                      assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    rewriter.create<cf::BranchOp>(loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate file name, line number and function name from the location of
    // the AssertOp. "(unknown)" / 0 are used if the location does not provide
    // them.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    // Strip call sites on a separate variable: `loc` is still attached to the
    // ops created below and must keep the complete call site information.
    Location sourceLoc = loc;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(sourceLoc))
      sourceLoc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(sourceLoc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(sourceLoc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the first character of the string stored in `global`.
      Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
          loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                       globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
    // NOTE(review): the last i64 argument is presumably the character size
    // expected by `__assertfail` -- confirm against the CUDA device runtime.
    Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);

    // Insert function call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};
/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"
@ -358,7 +456,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
populateWithGenerated(patterns);
patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
converter);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(

View File

@ -296,6 +296,7 @@ struct LowerGpuOpsToROCDLOpsPass
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);

View File

@ -304,6 +304,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
arith::populateArithToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
cf::populateAssertToLLVMConversionPattern(converter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);

View File

@ -969,6 +969,35 @@ gpu.module @test_module_50 {
}
}
// CHECK-LABEL: gpu.module @test_module_51
// CHECK: llvm.mlir.global internal constant @[[func_name:.*]]("(unknown)\00") {addr_space = 0 : i32}
// CHECK: llvm.mlir.global internal constant @[[file_name:.*]]("{{.*}}gpu-to-nvvm.mlir{{.*}}") {addr_space = 0 : i32}
// CHECK: llvm.mlir.global internal constant @[[message:.*]]("assert message\00") {addr_space = 0 : i32}
// CHECK: llvm.func @__assertfail(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) attributes {passthrough = ["noreturn"]}
// CHECK: llvm.func @test_assert(%[[cond:.*]]: i1) attributes {gpu.kernel, nvvm.kernel} {
// CHECK: llvm.cond_br %[[cond]], ^[[after_block:.*]], ^[[assert_block:.*]]
// CHECK: ^[[assert_block]]:
// CHECK: %[[message_ptr:.*]] = llvm.mlir.addressof @[[message]] : !llvm.ptr
// CHECK: %[[message_start:.*]] = llvm.getelementptr %[[message_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<15 x i8>
// CHECK: %[[file_ptr:.*]] = llvm.mlir.addressof @[[file_name]] : !llvm.ptr
// CHECK: %[[file_start:.*]] = llvm.getelementptr %[[file_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8>
// CHECK: %[[func_ptr:.*]] = llvm.mlir.addressof @[[func_name]] : !llvm.ptr
// CHECK: %[[func_start:.*]] = llvm.getelementptr %[[func_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8>
// CHECK: %[[line_num:.*]] = llvm.mlir.constant({{.*}} : i32) : i32
// CHECK: %[[ptr:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: llvm.call @__assertfail(%[[message_start]], %[[file_start]], %[[line_num]], %[[func_start]], %[[ptr]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> ()
// CHECK: llvm.br ^[[after_block]]
// CHECK: ^[[after_block]]:
// CHECK: llvm.return
// CHECK: }
// Kernel with a single cf.assert on its i1 argument. The directives above
// expect a conditional branch around a call to `__assertfail` together with
// string globals for the message, the file name and the function name.
gpu.module @test_module_51 {
gpu.func @test_assert(%arg0: i1) kernel {
cf.assert %arg0, "assert message"
gpu.return
}
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module

View File

@ -0,0 +1,38 @@
// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
// RUN: | mlir-cpu-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
// RUN: --entry-point-result=void 2>&1 \
// RUN: | FileCheck %s
// CHECK-DAG: thread 0: print after passing assertion
// CHECK-DAG: thread 1: print after passing assertion
// CHECK-DAG: callee_file.cc:7: callee_func_name: block: [0,0,0], thread: [0,0,0] Assertion `failing assertion` failed.
// CHECK-DAG: callee_file.cc:7: callee_func_name: block: [0,0,0], thread: [1,0,0] Assertion `failing assertion` failed.
// CHECK-NOT: print after failing assertion
module attributes {gpu.container_module} {
gpu.module @kernels {
// The launch in @main passes false for %c0 and true for %c1.
gpu.func @test_assert(%c0: i1, %c1: i1) kernel {
%0 = gpu.thread_id x
// %c1 is true, so this assertion passes and the print below is reached.
cf.assert %c1, "passing assertion"
gpu.printf "thread %lld: print after passing assertion\n" %0 : index
// Test callsite(callsite(name)) location.
// %c0 is false, so this assertion fails. The expected output refers to the
// innermost callee of the location ("callee_func_name", "callee_file.cc":7),
// and the print after this assertion is expected not to appear.
cf.assert %c0, "failing assertion" loc(callsite(callsite("callee_func_name"("callee_file.cc":7:9) at "caller_file.cc":10:8) at "caller2_file.cc":11:12))
gpu.printf "thread %lld: print after failing assertion\n" %0 : index
gpu.return
}
}
func.func @main() {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0_i1 = arith.constant 0 : i1
%c1_i1 = arith.constant 1 : i1
// Launch a single block with two threads along x. The first kernel argument
// is false and the second one is true.
gpu.launch_func @kernels::@test_assert
blocks in (%c1, %c1, %c1)
threads in (%c2, %c1, %c1)
args(%c0_i1 : i1, %c1_i1 : i1)
return
}
}