[mlir][gpu] Extend mgpuModuleLoadJIT API to add assemblySize parameter (#189429)
When JITing SPIR-V using the LevelZero API, the API expects the length of the input, since the data is passed as a `void *`. The problem is that the length cannot be obtained with something like `strlen(reinterpret_cast<char *>(data))` in the `mgpuModuleLoadJIT` implementation, because the SPIR-V binary contains null bytes (i.e., the data is binary SPIR-V, not null-terminated text). As a result, we need to pass the `assemblySize` via `mgpuModuleLoadJIT(void *data, int optLevel, size_t assemblySize)`.
This commit is contained in:
parent
f26b30ea35
commit
ffd29734cc
@ -124,8 +124,8 @@ mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
|
||||
return module;
|
||||
}
|
||||
|
||||
extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoadJIT(void *data,
|
||||
int optLevel) {
|
||||
extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule
|
||||
mgpuModuleLoadJIT(void *data, int optLevel, size_t /*assmeblySize*/) {
|
||||
ScopedContext scopedContext;
|
||||
CUmodule module = nullptr;
|
||||
char jitErrorBuffer[4096] = {0};
|
||||
|
||||
@ -520,10 +520,21 @@ extern "C" ze_module_handle_t mgpuModuleLoad(const void *data,
|
||||
return catchAll([&]() { return loadModule(data, gpuBlobSize); });
|
||||
}
|
||||
|
||||
extern "C" ze_module_handle_t mgpuModuleLoadJIT(void *data, int optLevel) {
|
||||
extern "C" ze_module_handle_t mgpuModuleLoadJIT(void *data, int optLevel,
|
||||
size_t assemblySize) {
|
||||
// Account for extra null terminator added in embedBinaryImpl.
|
||||
// A null terminator is added during embedding binary for assembly format to
|
||||
// support JIT paths that expect null-terminated strings. However, for SPIR-V
|
||||
// binary format, the null terminator is not expected. So we need to subtract
|
||||
// the null terminator when loading SPIR-V binary.
|
||||
assert((assemblySize == 0 ||
|
||||
reinterpret_cast<char *>(data)[assemblySize - 1] == 0) &&
|
||||
"Expected null terminator at the end of the assembly string.");
|
||||
size_t actualAssemblySize = assemblySize - 1;
|
||||
assert(actualAssemblySize % 4 == 0 &&
|
||||
"SPIR-V binary size must be a multiple of 4");
|
||||
return catchAll([&]() {
|
||||
return loadModule(data, strlen(reinterpret_cast<char *>(data)),
|
||||
ZE_MODULE_FORMAT_IL_SPIRV);
|
||||
return loadModule(data, actualAssemblySize, ZE_MODULE_FORMAT_IL_SPIRV);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -38,7 +38,8 @@ extern "C" hipModule_t mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
|
||||
return module;
|
||||
}
|
||||
|
||||
extern "C" hipModule_t mgpuModuleLoadJIT(void *data, int optLevel) {
|
||||
extern "C" hipModule_t mgpuModuleLoadJIT(void *data, int optLevel,
|
||||
size_t /*assmeblySize*/) {
|
||||
assert(false && "This function is not available in HIP.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
|
||||
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
|
||||
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Export.h"
|
||||
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
||||
@ -142,16 +141,24 @@ static LogicalResult embedBinaryImpl(StringRef moduleName,
|
||||
auto *loadBlock = BasicBlock::Create(module.getContext(), "entry", loadFn);
|
||||
builder.SetInsertPoint(loadBlock);
|
||||
Value *moduleObj = [&] {
|
||||
Constant *binarySize =
|
||||
ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0));
|
||||
if (object.getFormat() == gpu::CompilationTarget::Assembly) {
|
||||
FunctionCallee moduleLoadFn = module.getOrInsertFunction(
|
||||
"mgpuModuleLoadJIT", FunctionType::get(ptrTy, {ptrTy, i32Ty}, false));
|
||||
"mgpuModuleLoadJIT", FunctionType::get(ptrTy,
|
||||
{
|
||||
ptrTy,
|
||||
i32Ty,
|
||||
i64Ty,
|
||||
},
|
||||
false));
|
||||
|
||||
Constant *optValue = ConstantInt::get(i32Ty, optLevel);
|
||||
return builder.CreateCall(moduleLoadFn, {serializedObj, optValue});
|
||||
return builder.CreateCall(moduleLoadFn,
|
||||
{serializedObj, optValue, binarySize});
|
||||
}
|
||||
FunctionCallee moduleLoadFn = module.getOrInsertFunction(
|
||||
"mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false));
|
||||
Constant *binarySize =
|
||||
ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0));
|
||||
return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize});
|
||||
}();
|
||||
builder.CreateStore(moduleObj, modulePtr);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user