[MLIR][GPU] Generalize gpu.printf op lowering to LLVM call pattern. (#164297)

Existing pattern for lowering gpu.printf op to LLVM call uses fixed
function name and calling convention.
Those two should be exposed as pass option to allow supporting Intel
Compute Runtime for GPU.

Also adds gpu.printf op pattern to GPU to LLVMSPV pass.
It may appear out of place, but integration test is added to XeVM
integration test as that is the current best folder for testing with
Intel Compute Runtime.
Test should be moved in the future if a better test folder is added.
This commit is contained in:
Sang Ik Lee 2025-10-23 08:32:53 -07:00 committed by GitHub
parent c6073d72ee
commit 150145486e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 71 additions and 7 deletions

View File

@ -507,7 +507,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
/*isVarArg=*/true);
LLVM::LLVMFuncOp printfDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
getOrDefineFunction(moduleOp, loc, rewriter, funcName, printfType);
printfDecl.setCConv(callingConvention);
// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateStringConstant(
@ -530,7 +531,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
printfArgs.push_back(stringStart);
printfArgs.append(argsRange.begin(), argsRange.end());
LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
call.setCConv(callingConvention);
rewriter.eraseOp(gpuPrintfOp);
return success();
}

View File

@ -10,6 +10,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
namespace mlir {
@ -142,13 +143,23 @@ struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
/// This pass will add a declaration of printf() to the GPUModule if needed
/// and separate out the format strings into global constants. For some
/// runtimes, such as OpenCL on AMD, this is sufficient setup, as the compiler
/// will lower printf calls to appropriate device-side code
/// will lower printf calls to appropriate device-side code.
/// However not all backends use the same calling convention and function
/// naming.
/// For example, the LLVM SPIRV backend requires calling convention
/// LLVM::cconv::CConv::SPIR_FUNC and function name needs to be
/// mangled as "_Z6printfPU3AS2Kcz".
/// Default callingConvention is LLVM::cconv::CConv::C and
/// funcName is "printf" but they can be customized as needed.
struct GPUPrintfOpToLLVMCallLowering
: public ConvertOpToLLVMPattern<gpu::PrintfOp> {
GPUPrintfOpToLLVMCallLowering(const LLVMTypeConverter &converter,
int addressSpace = 0)
GPUPrintfOpToLLVMCallLowering(
const LLVMTypeConverter &converter, int addressSpace = 0,
LLVM::cconv::CConv callingConvention = LLVM::cconv::CConv::C,
StringRef funcName = "printf")
: ConvertOpToLLVMPattern<gpu::PrintfOp>(converter),
addressSpace(addressSpace) {}
addressSpace(addressSpace), callingConvention(callingConvention),
funcName(funcName) {}
LogicalResult
matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
@ -156,6 +167,8 @@ struct GPUPrintfOpToLLVMCallLowering
private:
int addressSpace;
LLVM::cconv::CConv callingConvention;
StringRef funcName;
};
/// Lowering of gpu.printf to a vprintf standard library.

View File

@ -470,10 +470,13 @@ struct GPUToLLVMSPVConversionPass final
gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
gpu::ThreadIdOp>();
gpu::ThreadIdOp, gpu::PrintfOp>();
populateGpuToLLVMSPVConversionPatterns(converter, patterns);
populateGpuMemorySpaceAttributeConversions(converter);
patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/2,
LLVM::cconv::CConv::SPIR_FUNC,
"_Z6printfPU3AS2Kcz");
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))

View File

@ -0,0 +1,16 @@
// RUN: mlir-opt %s -convert-gpu-to-llvm-spv | FileCheck %s
gpu.module @test_module {
// CHECK: llvm.mlir.global internal constant @[[$PRINT_GLOBAL:[A-Za-z0-9_]+]]("Hello: %d\0A\00") {addr_space = 2 : i32}
// CHECK: llvm.func spir_funccc @_Z6printfPU3AS2Kcz(!llvm.ptr<2>, ...) -> i32
// CHECK-LABEL: llvm.func spir_funccc @test_printf
// CHECK: (%[[ARG0:.*]]: i32)
gpu.func @test_printf(%arg0: i32) {
// CHECK: %[[IMM0:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL]] : !llvm.ptr<2>
// CHECK-NEXT: %[[IMM2:.*]] = llvm.getelementptr %[[IMM0]][0, 0] : (!llvm.ptr<2>) -> !llvm.ptr<2>, !llvm.array<11 x i8>
// CHECK-NEXT: %{{.*}} = llvm.call spir_funccc @_Z6printfPU3AS2Kcz(%[[IMM2]], %[[ARG0]]) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32) -> i32
gpu.printf "Hello: %d\n", %arg0 : i32
gpu.return
}
}

View File

@ -0,0 +1,30 @@
// RUN: mlir-opt %s \
// RUN: | mlir-opt -pass-pipeline='builtin.module(cse,func.func(gpu-async-region),xevm-attach-target,gpu.module(convert-gpu-to-llvm-spv{use-64bit-index=true},convert-xevm-to-llvm,cse))' \
// RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse -gpu-module-to-binary \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_sycl_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
// RUN: --shared-libs=%mlir_c_runner_utils \
// RUN: --entry-point-result=void \
// RUN: | FileCheck %s
module @test attributes {gpu.container_module} {
gpu.module @test_module {
gpu.func @test_printf(%arg0: i32, %arg1: f32) kernel {
gpu.printf "Hello: %d\n", %arg0 : i32
gpu.printf "Hello: %f\n", %arg1 : f32
gpu.return
}
}
func.func @main() attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c11 = arith.constant 11 : i32
%c4 = arith.constant 4.0 : f32
// CHECK: Hello: 11
// CHECK: Hello: 4.000000
gpu.launch_func @test_module::@test_printf blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%c11 : i32, %c4 : f32)
return
}
}