[Offload] Add an unloadBinary interface to PluginInterface (#143873)

This allows removal of a specific Image from a Device, rather than
requiring all image data to outlive the device they were created for.

This is required for `ol_program_handle_t`s, which now specify the
lifetime of the buffer used to create the program.
This commit is contained in:
Ross Brunton 2025-06-25 14:53:18 +01:00 committed by GitHub
parent e90ab0e342
commit 0870c8838b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 77 additions and 66 deletions

View File

@ -13,7 +13,9 @@
def : Function {
let name = "olCreateProgram";
let desc = "Create a program for the device from the binary image pointed to by `ProgData`.";
let details = [];
let details = [
"The provided `ProgData` will be copied and need not outlive the returned handle",
];
let params = [
Param<"ol_device_handle_t", "Device", "handle of the device", PARAM_IN>,
Param<"const void*", "ProgData", "pointer to the program binary data", PARAM_IN>,

View File

@ -480,6 +480,14 @@ Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
}
Error olDestroyProgram_impl(ol_program_handle_t Program) {
auto &Device = Program->Image->getDevice();
if (auto Err = Device.unloadBinary(Program->Image))
return Err;
auto &LoadedImages = Device.LoadedImages;
LoadedImages.erase(
std::find(LoadedImages.begin(), LoadedImages.end(), Program->Image));
return olDestroy(Program);
}

View File

@ -2023,6 +2023,13 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
return Plugin::success();
}
Error unloadBinaryImpl(DeviceImageTy *Image) override {
AMDGPUDeviceImageTy &AMDImage = static_cast<AMDGPUDeviceImageTy &>(*Image);
// Unload the executable of the image.
return AMDImage.unloadExecutable();
}
/// Deinitialize the device and release its resources.
Error deinitImpl() override {
// Deinitialize the stream and event pools.
@ -2035,19 +2042,6 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
if (auto Err = AMDGPUSignalManager.deinit())
return Err;
// Close modules if necessary.
if (!LoadedImages.empty()) {
// Each image has its own module.
for (DeviceImageTy *Image : LoadedImages) {
AMDGPUDeviceImageTy &AMDImage =
static_cast<AMDGPUDeviceImageTy &>(*Image);
// Unload the executable of the image.
if (auto Err = AMDImage.unloadExecutable())
return Err;
}
}
// Invalidate agent reference.
Agent = {0};

View File

@ -752,6 +752,10 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
virtual Expected<DeviceImageTy *>
loadBinaryImpl(const __tgt_device_image *TgtImage, int32_t ImageId) = 0;
/// Unload a previously loaded Image from the device
Error unloadBinary(DeviceImageTy *Image);
virtual Error unloadBinaryImpl(DeviceImageTy *Image) = 0;
/// Setup the device environment if needed. Notice this setup may not be run
/// on some plugins. By default, it will be executed, but plugins can change
/// this behavior by overriding the shouldSetupDeviceEnvironment function.
@ -1036,6 +1040,10 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
BoolEnvar OMPX_TrackAllocationTraces =
BoolEnvar("OFFLOAD_TRACK_ALLOCATION_TRACES", false);
/// Array of images loaded into the device. Images are automatically
/// deallocated by the allocator.
llvm::SmallVector<DeviceImageTy *> LoadedImages;
private:
/// Get and set the stack size and heap size for the device. If not used, the
/// plugin can implement the setters as no-op and setting the output
@ -1086,10 +1094,6 @@ protected:
UInt32Envar OMPX_InitialNumStreams;
UInt32Envar OMPX_InitialNumEvents;
/// Array of images loaded into the device. Images are automatically
/// deallocated by the allocator.
llvm::SmallVector<DeviceImageTy *> LoadedImages;
/// The identifier of the device within the plugin. Notice this is not a
/// global device id and is not the device id visible to the OpenMP user.
const int32_t DeviceId;

View File

@ -821,26 +821,49 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
return Plugin::success();
}
Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
for (DeviceImageTy *Image : LoadedImages)
if (auto Err = callGlobalDestructors(Plugin, *Image))
return Err;
Error GenericDeviceTy::unloadBinary(DeviceImageTy *Image) {
if (auto Err = callGlobalDestructors(Plugin, *Image))
return Err;
if (OMPX_DebugKind.get() & uint32_t(DeviceDebugKind::AllocationTracker)) {
GenericGlobalHandlerTy &GHandler = Plugin.getGlobalHandler();
for (auto *Image : LoadedImages) {
DeviceMemoryPoolTrackingTy ImageDeviceMemoryPoolTracking = {0, 0, ~0U, 0};
GlobalTy TrackerGlobal("__omp_rtl_device_memory_pool_tracker",
sizeof(DeviceMemoryPoolTrackingTy),
&ImageDeviceMemoryPoolTracking);
if (auto Err =
GHandler.readGlobalFromDevice(*this, *Image, TrackerGlobal)) {
consumeError(std::move(Err));
continue;
}
DeviceMemoryPoolTracking.combine(ImageDeviceMemoryPoolTracking);
DeviceMemoryPoolTrackingTy ImageDeviceMemoryPoolTracking = {0, 0, ~0U, 0};
GlobalTy TrackerGlobal("__omp_rtl_device_memory_pool_tracker",
sizeof(DeviceMemoryPoolTrackingTy),
&ImageDeviceMemoryPoolTracking);
if (auto Err =
GHandler.readGlobalFromDevice(*this, *Image, TrackerGlobal)) {
consumeError(std::move(Err));
}
DeviceMemoryPoolTracking.combine(ImageDeviceMemoryPoolTracking);
}
GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
auto ProfOrErr = Handler.readProfilingGlobals(*this, *Image);
if (!ProfOrErr)
return ProfOrErr.takeError();
if (!ProfOrErr->empty()) {
// Dump out profdata
if ((OMPX_DebugKind.get() & uint32_t(DeviceDebugKind::PGODump)) ==
uint32_t(DeviceDebugKind::PGODump))
ProfOrErr->dump();
// Write data to profiling file
if (auto Err = ProfOrErr->write())
return Err;
}
return unloadBinaryImpl(Image);
}
Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
for (auto &I : LoadedImages)
if (auto Err = unloadBinary(I))
return Err;
LoadedImages.clear();
if (OMPX_DebugKind.get() & uint32_t(DeviceDebugKind::AllocationTracker)) {
// TODO: Write this by default into a file.
printf("\n\n|-----------------------\n"
"| Device memory tracker:\n"
@ -856,25 +879,6 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
DeviceMemoryPoolTracking.AllocationMax);
}
for (auto *Image : LoadedImages) {
GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
auto ProfOrErr = Handler.readProfilingGlobals(*this, *Image);
if (!ProfOrErr)
return ProfOrErr.takeError();
if (ProfOrErr->empty())
continue;
// Dump out profdata
if ((OMPX_DebugKind.get() & uint32_t(DeviceDebugKind::PGODump)) ==
uint32_t(DeviceDebugKind::PGODump))
ProfOrErr->dump();
// Write data to profiling file
if (auto Err = ProfOrErr->write())
return Err;
}
// Delete the memory manager before deinitializing the device. Otherwise,
// we may delete device allocations after the device is deinitialized.
if (MemoryManager)

View File

@ -358,6 +358,19 @@ struct CUDADeviceTy : public GenericDeviceTy {
return Plugin::success();
}
Error unloadBinaryImpl(DeviceImageTy *Image) override {
assert(Context && "Invalid CUDA context");
// Each image has its own module.
CUDADeviceImageTy &CUDAImage = static_cast<CUDADeviceImageTy &>(*Image);
// Unload the module of the image.
if (auto Err = CUDAImage.unloadModule())
return Err;
return Plugin::success();
}
/// Deinitialize the device and release its resources.
Error deinitImpl() override {
if (Context) {
@ -372,20 +385,6 @@ struct CUDADeviceTy : public GenericDeviceTy {
if (auto Err = CUDAEventManager.deinit())
return Err;
// Close modules if necessary.
if (!LoadedImages.empty()) {
assert(Context && "Invalid CUDA context");
// Each image has its own module.
for (DeviceImageTy *Image : LoadedImages) {
CUDADeviceImageTy &CUDAImage = static_cast<CUDADeviceImageTy &>(*Image);
// Unload the module of the image.
if (auto Err = CUDAImage.unloadModule())
return Err;
}
}
if (Context) {
CUresult Res = cuDevicePrimaryCtxRelease(Device);
if (auto Err =