[mlir][GPU] Add NVVM-specific cf.assert lowering (#120431)
This commit adds an NVIDIA-specific lowering of `cf.assert` to `__assertfail`. Note: `getUniqueFormatGlobalName`, `getOrCreateFormatStringConstant` and `getOrDefineFunction` are moved to `GPUOpsLowering.h`, so that they can be reused.
This commit is contained in:
parent
a13bcf3ced
commit
599c739905
@ -3928,6 +3928,7 @@ public:
|
||||
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern);
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
|
||||
pattern);
|
||||
mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, pattern);
|
||||
// Math operations that have not been converted yet must be converted
|
||||
// to Libm.
|
||||
if (!isAMDGCN)
|
||||
|
@ -220,6 +220,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
|
||||
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
|
||||
populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
||||
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
|
||||
cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
|
||||
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
// The only remaining operation to lower from the `toy` dialect, is the
|
||||
|
@ -29,6 +29,10 @@ namespace cf {
|
||||
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
|
||||
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
|
||||
/// references have to remain alive during the entire pattern lifetime.
|
||||
///
|
||||
/// Note: This function does not populate the default cf.assert lowering. That
|
||||
/// is because some platforms have a custom cf.assert lowering. The default
|
||||
/// lowering can be populated with `populateAssertToLLVMConversionPattern`.
|
||||
void populateControlFlowToLLVMConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
|
||||
|
@ -215,7 +215,6 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
// clang-format off
|
||||
patterns.add<
|
||||
AssertOpLowering,
|
||||
BranchOpLowering,
|
||||
CondBranchOpLowering,
|
||||
SwitchOpLowering>(converter);
|
||||
@ -258,6 +257,7 @@ struct ConvertControlFlowToLLVM
|
||||
LLVMTypeConverter converter(ctx, options);
|
||||
RewritePatternSet patterns(ctx);
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
|
||||
mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
@ -286,6 +286,7 @@ struct ControlFlowToLLVMDialectInterface
|
||||
RewritePatternSet &patterns) const final {
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -19,6 +19,59 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Return the `LLVM::LLVMFuncOp` called `name` from `moduleOp`. If the lookup
/// finds no such function op, create one with function type `type` and
/// external linkage at the start of the module body and return it. The
/// insertion point of `b` is restored before returning.
LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
                                           Location loc, OpBuilder &b,
                                           StringRef name,
                                           LLVM::LLVMFunctionType type) {
  // Reuse the function op when the module already provides it.
  if (auto existingFunc = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))
    return existingFunc;

  // Otherwise emit a fresh external function op at the top of the module.
  OpBuilder::InsertionGuard restoreInsertionPoint(b);
  b.setInsertionPointToStart(moduleOp.getBody());
  return b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
}
|
||||
|
||||
/// Build a symbol name of the form `<prefix><N>` that is not yet used inside
/// `moduleOp`. `N` counts up from 0 and the first name for which the module's
/// symbol lookup comes back empty is returned.
static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
                                           StringRef prefix) {
  SmallString<16> candidate;
  for (unsigned suffix = 0;; ++suffix) {
    // Rebuild the candidate from scratch for every suffix that is tried.
    candidate.clear();
    (prefix + Twine(suffix)).toVector(candidate);
    if (!moduleOp.lookupSymbol(candidate))
      return candidate;
  }
}
|
||||
|
||||
/// Return a constant global in `moduleOp` that holds `str` followed by a NUL
/// byte.
///
/// An existing `LLVM::GlobalOp` is reused when it is constant and its type,
/// value, alignment (a missing alignment counts as 0) and address space all
/// match the request. Otherwise a new internal constant global is created at
/// the start of the module body; its name is `namePrefix` followed by a
/// number that makes the symbol unique within the module.
LLVM::GlobalOp
mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
                                gpu::GPUModuleOp moduleOp, Type llvmI8,
                                StringRef namePrefix, StringRef str,
                                uint64_t alignment, unsigned addrSpace) {
  // The consumers are C functions, so the terminator is part of the payload.
  llvm::SmallString<20> payload(str);
  payload.push_back('\0');
  auto arrayType = LLVM::LLVMArrayType::get(llvmI8, payload.size_in_bytes());
  StringAttr payloadAttr = b.getStringAttr(payload);

  // A global may stand in for the requested one only if every property agrees.
  auto isEquivalent = [&](LLVM::GlobalOp candidate) {
    return candidate.getGlobalType() == arrayType && candidate.getConstant() &&
           candidate.getValueAttr() == payloadAttr &&
           candidate.getAlignment().value_or(0) == alignment &&
           candidate.getAddrSpace() == addrSpace;
  };
  for (LLVM::GlobalOp candidate : moduleOp.getOps<LLVM::GlobalOp>()) {
    if (isEquivalent(candidate))
      return candidate;
  }

  // Nothing suitable exists yet: emit a new global at the top of the module.
  SmallString<16> symbolName = getUniqueSymbolName(moduleOp, namePrefix);
  OpBuilder::InsertionGuard restoreInsertionPoint(b);
  b.setInsertionPointToStart(moduleOp.getBody());
  return b.create<LLVM::GlobalOp>(loc, arrayType,
                                  /*isConstant=*/true, LLVM::Linkage::Internal,
                                  symbolName, payloadAttr, alignment,
                                  addrSpace);
}
|
||||
|
||||
LogicalResult
|
||||
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Pick a symbol name `printfFormat_<N>` that is still free in `moduleOp`,
/// trying N = 0, 1, 2, ... until the module's symbol lookup finds nothing.
static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
  SmallString<16> candidate;
  for (unsigned suffix = 0;; ++suffix) {
    // Start over with an empty buffer for each suffix.
    candidate.clear();
    (Twine("printfFormat_") + Twine(suffix)).toVector(candidate);
    if (!moduleOp.lookupSymbol(candidate))
      return candidate;
  }
}
|
||||
|
||||
/// Create an global that contains the given format string. If a global with
|
||||
/// the same format string exists already in the module, return that global.
|
||||
static LLVM::GlobalOp getOrCreateFormatStringConstant(
|
||||
OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
|
||||
StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
|
||||
llvm::SmallString<20> formatString(str);
|
||||
formatString.push_back('\0'); // Null terminate for C
|
||||
auto globalType =
|
||||
LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
|
||||
StringAttr attr = b.getStringAttr(formatString);
|
||||
|
||||
// Try to find existing global.
|
||||
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
|
||||
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
|
||||
globalOp.getValueAttr() == attr &&
|
||||
globalOp.getAlignment().value_or(0) == alignment &&
|
||||
globalOp.getAddrSpace() == addrSpace)
|
||||
return globalOp;
|
||||
|
||||
// Not found: create new global.
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
b.setInsertionPointToStart(moduleOp.getBody());
|
||||
SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
|
||||
return b.create<LLVM::GlobalOp>(loc, globalType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal,
|
||||
name, attr, alignment, addrSpace);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
StringRef name,
|
||||
LLVM::LLVMFunctionType type) {
|
||||
LLVM::LLVMFuncOp ret;
|
||||
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
|
||||
LLVM::Linkage::External);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
|
||||
gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
|
||||
Value printfDesc = printfBeginCall.getResult();
|
||||
|
||||
// Create the global op or find an existing one.
|
||||
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
|
||||
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
|
||||
LLVM::GlobalOp global = getOrCreateStringConstant(
|
||||
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
|
||||
|
||||
// Get a pointer to the format string's first element and pass it to printf()
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
|
||||
@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
|
||||
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
|
||||
|
||||
// Create the global op or find an existing one.
|
||||
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
|
||||
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
|
||||
addressSpace);
|
||||
LLVM::GlobalOp global = getOrCreateStringConstant(
|
||||
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
|
||||
/*alignment=*/0, addressSpace);
|
||||
|
||||
// Get a pointer to the format string's first element
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
|
||||
@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
|
||||
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
|
||||
|
||||
// Create the global op or find an existing one.
|
||||
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
|
||||
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
|
||||
LLVM::GlobalOp global = getOrCreateStringConstant(
|
||||
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
|
||||
|
||||
// Get a pointer to the format string's first element
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
|
||||
|
@ -14,6 +14,27 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper Functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Find or create an external function declaration in the given module.
///
/// Looks up an `LLVM::LLVMFuncOp` symbol called `name` in `moduleOp` and
/// returns it when found. Otherwise a new function op with function type
/// `type` and external linkage is created with `b` at the start of the module
/// body, using `loc` as its location. The insertion point of `b` is left
/// unchanged.
|
||||
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
|
||||
OpBuilder &b, StringRef name,
|
||||
LLVM::LLVMFunctionType type);
|
||||
|
||||
/// Create a global that contains the given string. If a global with the same
|
||||
/// string already exists in the module, return that global.
///
/// The string is stored with a trailing NUL byte as an array of `llvmI8`
/// elements. An existing global is only reused if it is constant and its
/// type, value, alignment (a missing alignment is treated as 0) and address
/// space all match. A newly created global is an internal constant placed at
/// the start of the module body; its symbol name is `namePrefix` followed by
/// the first number, counting up from 0, that does not clash with a symbol
/// already present in the module.
|
||||
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
|
||||
gpu::GPUModuleOp moduleOp, Type llvmI8,
|
||||
StringRef namePrefix, StringRef str,
|
||||
uint64_t alignment = 0,
|
||||
unsigned addrSpace = 0);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Lowering Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
|
||||
/// create a 0-sized global array symbol similar as LLVM expects. It constructs
|
||||
/// a memref descriptor with these values and return it.
|
||||
|
@ -25,6 +25,7 @@
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/GPU/Transforms/Passes.h"
|
||||
@ -236,6 +237,103 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Lowering of cf.assert into a conditional __assertfail.
|
||||
struct AssertOpToAssertfailLowering
|
||||
: public ConvertOpToLLVMPattern<cf::AssertOp> {
|
||||
using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
Location loc = assertOp.getLoc();
|
||||
Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
|
||||
Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
|
||||
Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
|
||||
Type ptrType = LLVM::LLVMPointerType::get(ctx);
|
||||
Type voidType = LLVM::LLVMVoidType::get(ctx);
|
||||
|
||||
// Find or create __assertfail function declaration.
|
||||
auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
|
||||
auto assertfailType = LLVM::LLVMFunctionType::get(
|
||||
voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
|
||||
LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
|
||||
moduleOp, loc, rewriter, "__assertfail", assertfailType);
|
||||
assertfailDecl.setPassthroughAttr(
|
||||
ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
|
||||
|
||||
// Split blocks and insert conditional branch.
|
||||
// ^before:
|
||||
// ...
|
||||
// cf.cond_br %condition, ^after, ^assert
|
||||
// ^assert:
|
||||
// cf.assert
|
||||
// cf.br ^after
|
||||
// ^after:
|
||||
// ...
|
||||
Block *beforeBlock = assertOp->getBlock();
|
||||
Block *assertBlock =
|
||||
rewriter.splitBlock(beforeBlock, assertOp->getIterator());
|
||||
Block *afterBlock =
|
||||
rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
|
||||
rewriter.setInsertionPointToEnd(beforeBlock);
|
||||
rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
|
||||
assertBlock);
|
||||
rewriter.setInsertionPointToEnd(assertBlock);
|
||||
rewriter.create<cf::BranchOp>(loc, afterBlock);
|
||||
|
||||
// Continue cf.assert lowering.
|
||||
rewriter.setInsertionPoint(assertOp);
|
||||
|
||||
// Populate file name, file number and function name from the location of
|
||||
// the AssertOp.
|
||||
StringRef fileName = "(unknown)";
|
||||
StringRef funcName = "(unknown)";
|
||||
int32_t fileLine = 0;
|
||||
while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
|
||||
loc = callSiteLoc.getCallee();
|
||||
if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
|
||||
fileName = fileLineColLoc.getFilename().strref();
|
||||
fileLine = fileLineColLoc.getStartLine();
|
||||
} else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
|
||||
funcName = nameLoc.getName().strref();
|
||||
if (auto fileLineColLoc =
|
||||
dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
|
||||
fileName = fileLineColLoc.getFilename().strref();
|
||||
fileLine = fileLineColLoc.getStartLine();
|
||||
}
|
||||
}
|
||||
|
||||
// Create constants.
|
||||
auto getGlobal = [&](LLVM::GlobalOp global) {
|
||||
// Get a pointer to the format string's first element.
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
|
||||
loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
|
||||
global.getSymNameAttr());
|
||||
Value start =
|
||||
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
|
||||
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
return start;
|
||||
};
|
||||
Value assertMessage = getGlobal(getOrCreateStringConstant(
|
||||
rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
|
||||
Value assertFile = getGlobal(getOrCreateStringConstant(
|
||||
rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
|
||||
Value assertFunc = getGlobal(getOrCreateStringConstant(
|
||||
rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
|
||||
Value assertLine =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
|
||||
Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
|
||||
|
||||
// Insert function call to __assertfail.
|
||||
SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
|
||||
assertFunc, c1};
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
|
||||
arguments);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Import the GPU Ops to NVVM Patterns.
|
||||
#include "GPUToNVVM.cpp.inc"
|
||||
|
||||
@ -358,7 +456,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
|
||||
using gpu::index_lowering::IndexKind;
|
||||
using gpu::index_lowering::IntrType;
|
||||
populateWithGenerated(patterns);
|
||||
patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
|
||||
patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
|
||||
converter);
|
||||
patterns.add<
|
||||
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
|
||||
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
|
||||
|
@ -296,6 +296,7 @@ struct LowerGpuOpsToROCDLOpsPass
|
||||
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
|
||||
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
|
||||
|
@ -304,6 +304,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
|
||||
LLVMTypeConverter converter(&getContext());
|
||||
arith::populateArithToLLVMConversionPatterns(converter, patterns);
|
||||
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
|
||||
cf::populateAssertToLLVMConversionPattern(converter, patterns);
|
||||
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
|
||||
populateFuncToLLVMConversionPatterns(converter, patterns);
|
||||
populateOpenMPToLLVMConversionPatterns(converter, patterns);
|
||||
|
@ -969,6 +969,35 @@ gpu.module @test_module_50 {
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: gpu.module @test_module_51
|
||||
// CHECK: llvm.mlir.global internal constant @[[func_name:.*]]("(unknown)\00") {addr_space = 0 : i32}
|
||||
// CHECK: llvm.mlir.global internal constant @[[file_name:.*]]("{{.*}}gpu-to-nvvm.mlir{{.*}}") {addr_space = 0 : i32}
|
||||
// CHECK: llvm.mlir.global internal constant @[[message:.*]]("assert message\00") {addr_space = 0 : i32}
|
||||
// CHECK: llvm.func @__assertfail(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) attributes {passthrough = ["noreturn"]}
|
||||
// CHECK: llvm.func @test_assert(%[[cond:.*]]: i1) attributes {gpu.kernel, nvvm.kernel} {
|
||||
// CHECK: llvm.cond_br %[[cond]], ^[[after_block:.*]], ^[[assert_block:.*]]
|
||||
// CHECK: ^[[assert_block]]:
|
||||
// CHECK: %[[message_ptr:.*]] = llvm.mlir.addressof @[[message]] : !llvm.ptr
|
||||
// CHECK: %[[message_start:.*]] = llvm.getelementptr %[[message_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<15 x i8>
|
||||
// CHECK: %[[file_ptr:.*]] = llvm.mlir.addressof @[[file_name]] : !llvm.ptr
|
||||
// CHECK: %[[file_start:.*]] = llvm.getelementptr %[[file_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8>
|
||||
// CHECK: %[[func_ptr:.*]] = llvm.mlir.addressof @[[func_name]] : !llvm.ptr
|
||||
// CHECK: %[[func_start:.*]] = llvm.getelementptr %[[func_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8>
|
||||
// CHECK: %[[line_num:.*]] = llvm.mlir.constant({{.*}} : i32) : i32
|
||||
// CHECK: %[[ptr:.*]] = llvm.mlir.constant(1 : i64) : i64
|
||||
// CHECK: llvm.call @__assertfail(%[[message_start]], %[[file_start]], %[[line_num]], %[[func_start]], %[[ptr]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> ()
|
||||
// CHECK: llvm.br ^[[after_block]]
|
||||
// CHECK: ^[[after_block]]:
|
||||
// CHECK: llvm.return
|
||||
// CHECK: }
|
||||
|
||||
// Input for the cf.assert lowering test: a kernel whose only work is a single
// assertion on its i1 argument. The FileCheck expectations above describe the
// expected output for this module.
gpu.module @test_module_51 {
|
||||
gpu.func @test_assert(%arg0: i1) kernel {
|
||||
cf.assert %arg0, "assert message"
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
|
||||
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
|
||||
|
38
mlir/test/Integration/GPU/CUDA/assert.mlir
Normal file
38
mlir/test/Integration/GPU/CUDA/assert.mlir
Normal file
@ -0,0 +1,38 @@
|
||||
// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
|
||||
// RUN: | mlir-cpu-runner \
|
||||
// RUN: --shared-libs=%mlir_cuda_runtime \
|
||||
// RUN: --shared-libs=%mlir_runner_utils \
|
||||
// RUN: --entry-point-result=void 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK-DAG: thread 0: print after passing assertion
|
||||
// CHECK-DAG: thread 1: print after passing assertion
|
||||
// CHECK-DAG: callee_file.cc:7: callee_func_name: block: [0,0,0], thread: [0,0,0] Assertion `failing assertion` failed.
|
||||
// CHECK-DAG: callee_file.cc:7: callee_func_name: block: [0,0,0], thread: [1,0,0] Assertion `failing assertion` failed.
|
||||
// CHECK-NOT: print after failing assertion
|
||||
|
||||
// End-to-end test input: @main launches @kernels::@test_assert with two
// threads in a single block. The kernel first runs an assertion that holds,
// then one that fails.
module attributes {gpu.container_module} {
|
||||
gpu.module @kernels {
|
||||
// @main passes %c0 = false and %c1 = true (see the `args` of the launch).
gpu.func @test_assert(%c0: i1, %c1: i1) kernel {
|
||||
%0 = gpu.thread_id x
|
||||
// %c1 is true: this assertion holds and the following printf runs.
cf.assert %c1, "passing assertion"
|
||||
gpu.printf "thread %lld: print after passing assertion\n" %0 : index
|
||||
// Test callsite(callsite(name)) location.
// The expected output names the innermost callee ("callee_func_name" at
// callee_file.cc:7), not either of the caller locations.
|
||||
cf.assert %c0, "failing assertion" loc(callsite(callsite("callee_func_name"("callee_file.cc":7:9) at "caller_file.cc":10:8) at "caller2_file.cc":11:12))
|
||||
// The test expects this printf never to produce output.
gpu.printf "thread %lld: print after failing assertion\n" %0 : index
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
// Host entry point: builds the constants and launches the kernel.
func.func @main() {
|
||||
%c2 = arith.constant 2 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0_i1 = arith.constant 0 : i1
|
||||
%c1_i1 = arith.constant 1 : i1
|
||||
// Grid of 1x1x1 blocks with 2x1x1 threads each.
gpu.launch_func @kernels::@test_assert
|
||||
blocks in (%c1, %c1, %c1)
|
||||
threads in (%c2, %c1, %c1)
|
||||
args(%c0_i1 : i1, %c1_i1 : i1)
|
||||
return
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user