[mlir][gpu][nvptx] Remove null terminator when outputting PTX (#133019)

PTX source files are expected to only contain ASCII text
(https://docs.nvidia.com/cuda/parallel-thread-execution/#source-format) and no null terminators.

`ptxas` has so far not enforced this but is moving towards doing so.
This revealed a problem where the null terminator was being written to
the output file in the MLIR path when outputting PTX directly. The fix is to add the null terminator only on the assembly output path used for JIT, instead of in the output of `moduleToObject`.
This commit is contained in:
modiking 2025-04-03 15:50:54 -07:00 committed by GitHub
parent f1c6612202
commit 9f2feeb189
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 7 additions and 7 deletions

View File

@ -722,12 +722,8 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
#undef DEBUG_TYPE
// Return PTX if the compilation target is `assembly`.
if (targetOptions.getCompilationTarget() ==
gpu::CompilationTarget::Assembly) {
// Make sure to include the null terminator.
StringRef bin(serializedISA->c_str(), serializedISA->size() + 1);
return SmallVector<char, 0>(bin.begin(), bin.end());
}
if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Assembly)
return SmallVector<char, 0>(serializedISA->begin(), serializedISA->end());
std::optional<SmallVector<char, 0>> result;
moduleToObjectTimer.startTimer();

View File

@ -116,8 +116,11 @@ LogicalResult SelectObjectAttrImpl::embedBinary(
llvm::Module *module = moduleTranslation.getLLVMModule();
// Embed the object as a global string.
// Add null for assembly output for JIT paths that expect null-terminated
// strings.
bool addNull = (object.getFormat() == gpu::CompilationTarget::Assembly);
llvm::Constant *binary = llvm::ConstantDataArray::getString(
builder.getContext(), object.getObject().getValue(), false);
builder.getContext(), object.getObject().getValue(), addNull);
llvm::GlobalVariable *serializedObj =
new llvm::GlobalVariable(*module, binary->getType(), true,
llvm::GlobalValue::LinkageTypes::InternalLinkage,

View File

@ -130,6 +130,7 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToPTX)) {
ASSERT_TRUE(
StringRef(object->data(), object->size()).contains("nvvm_kernel"));
ASSERT_TRUE(StringRef(object->data(), object->size()).count('\0') == 0);
}
}