[Offload] Change ol_kernel_handle_t -> ol_symbol_handle_t (#147943)

In the future, we want `ol_symbol_handle_t` to represent both kernels
and global variables The first step in this process is a rename and
promotion to a "typed handle".
This commit is contained in:
Ross Brunton 2025-07-10 14:54:10 +01:00 committed by GitHub
parent 13ead00049
commit 466357ab51
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 54 additions and 18 deletions

View File

@ -74,10 +74,9 @@ def : Handle {
let desc = "Handle of program object";
}
def : Typedef {
let name = "ol_kernel_handle_t";
let desc = "Handle of kernel object";
let value = "void *";
def : Handle {
let name = "ol_symbol_handle_t";
let desc = "Handle of an object in a device's memory for a specific program";
}
def ErrorCode : Enum {
@ -112,6 +111,7 @@ def ErrorCode : Enum {
Etor<"INVALID_DEVICE", "invalid device">,
Etor<"INVALID_QUEUE", "invalid queue">,
Etor<"INVALID_EVENT", "invalid event">,
Etor<"SYMBOL_KIND", "the operation does not support this symbol kind">,
];
}

View File

@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
// This file contains Offload API definitions related to the kernel handle
// This file contains Offload API definitions related to loading and launching
// kernels
//
//===----------------------------------------------------------------------===//
@ -14,12 +15,12 @@ def : Function {
let name = "olGetKernel";
let desc = "Get a kernel from the function identified by `KernelName` in the given program.";
let details = [
"The kernel handle returned is owned by the device so does not need to be destroyed."
"Symbol handles are owned by the program and do not need to be manually destroyed."
];
let params = [
Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>,
Param<"const char*", "KernelName", "name of the kernel entry point in the program", PARAM_IN>,
Param<"ol_kernel_handle_t*", "Kernel", "output pointer for the fetched kernel", PARAM_OUT>
Param<"ol_symbol_handle_t*", "Kernel", "output pointer for the fetched kernel", PARAM_OUT>
];
let returns = [];
}
@ -45,7 +46,7 @@ def : Function {
let params = [
Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN_OPTIONAL>,
Param<"ol_device_handle_t", "Device", "handle of the device to execute on", PARAM_IN>,
Param<"ol_kernel_handle_t", "Kernel", "handle of the kernel", PARAM_IN>,
Param<"ol_symbol_handle_t", "Kernel", "handle of the kernel", PARAM_IN>,
Param<"const void*", "ArgumentsData", "pointer to the kernel argument struct", PARAM_IN_OPTIONAL>,
Param<"size_t", "ArgumentsSize", "size of the kernel argument struct", PARAM_IN>,
Param<"const ol_kernel_launch_size_args_t*", "LaunchSizeArgs", "pointer to the struct containing launch size parameters", PARAM_IN>,
@ -55,5 +56,6 @@ def : Function {
Return<"OL_ERRC_INVALID_ARGUMENT", ["`Queue == NULL && EventOut != NULL`"]>,
Return<"OL_ERRC_INVALID_ARGUMENT", ["`ArgumentsSize > 0 && ArgumentsData == NULL`"]>,
Return<"OL_ERRC_INVALID_DEVICE", ["If Queue is non-null but does not belong to Device"]>,
Return<"OL_ERRC_SYMBOL_KIND", ["The provided symbol is not a kernel"]>,
];
}

View File

@ -18,3 +18,4 @@ include "Queue.td"
include "Event.td"
include "Program.td"
include "Kernel.td"
include "Symbol.td"

View File

@ -0,0 +1,19 @@
//===-- Symbol.td - Symbol definitions for Offload ---------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains Offload API definitions related to the symbol handle.
//
//===----------------------------------------------------------------------===//
def : Enum {
let name = "ol_symbol_kind_t";
let desc = "The kind of a symbol";
let etors =[
Etor<"KERNEL", "a kernel object">,
];
}

View File

@ -84,9 +84,17 @@ struct ol_program_impl_t {
DeviceImage(DeviceImage) {}
plugin::DeviceImageTy *Image;
std::unique_ptr<llvm::MemoryBuffer> ImageData;
std::vector<std::unique_ptr<ol_symbol_impl_t>> Symbols;
__tgt_device_image DeviceImage;
};
struct ol_symbol_impl_t {
ol_symbol_impl_t(GenericKernelTy *Kernel)
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {}
std::variant<GenericKernelTy *> PluginImpl;
ol_symbol_kind_t Kind;
};
namespace llvm {
namespace offload {
@ -653,7 +661,7 @@ Error olDestroyProgram_impl(ol_program_handle_t Program) {
}
Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName,
ol_kernel_handle_t *Kernel) {
ol_symbol_handle_t *Kernel) {
auto &Device = Program->Image->getDevice();
auto KernelImpl = Device.constructKernel(KernelName);
@ -663,13 +671,15 @@ Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName,
if (auto Err = KernelImpl->init(Device, *Program->Image))
return Err;
*Kernel = &*KernelImpl;
*Kernel = Program->Symbols
.emplace_back(std::make_unique<ol_symbol_impl_t>(&*KernelImpl))
.get();
return Error::success();
}
Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
ol_kernel_handle_t Kernel, const void *ArgumentsData,
ol_symbol_handle_t Kernel, const void *ArgumentsData,
size_t ArgumentsSize,
const ol_kernel_launch_size_args_t *LaunchSizeArgs,
ol_event_handle_t *EventOut) {
@ -680,6 +690,10 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
"device specified does not match the device of the given queue");
}
if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL)
return createOffloadError(ErrorCode::SYMBOL_KIND,
"provided symbol is not a kernel");
auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl);
KernelArgsTy LaunchArgs{};
@ -698,7 +712,7 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
// Don't do anything with pointer indirection; use arg data as-is
LaunchArgs.Flags.IsCUDA = true;
auto *KernelImpl = reinterpret_cast<GenericKernelTy *>(Kernel);
auto *KernelImpl = std::get<GenericKernelTy *>(Kernel->PluginImpl);
auto Err = KernelImpl->launch(*DeviceImpl, LaunchArgs.ArgPtrs, nullptr,
LaunchArgs, AsyncInfoWrapper);

View File

@ -120,7 +120,7 @@ struct OffloadKernelTest : OffloadProgramTest {
RETURN_ON_FATAL_FAILURE(OffloadProgramTest::TearDown());
}
ol_kernel_handle_t Kernel = nullptr;
ol_symbol_handle_t Kernel = nullptr;
};
struct OffloadQueueTest : OffloadDeviceTest {

View File

@ -14,13 +14,13 @@ using olGetKernelTest = OffloadProgramTest;
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetKernelTest);
TEST_P(olGetKernelTest, Success) {
ol_kernel_handle_t Kernel = nullptr;
ol_symbol_handle_t Kernel = nullptr;
ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel));
ASSERT_NE(Kernel, nullptr);
}
TEST_P(olGetKernelTest, InvalidNullProgram) {
ol_kernel_handle_t Kernel = nullptr;
ol_symbol_handle_t Kernel = nullptr;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
olGetKernel(nullptr, "foo", &Kernel));
}
@ -32,7 +32,7 @@ TEST_P(olGetKernelTest, InvalidNullKernelPointer) {
// Error code returning from plugin interface not yet supported
TEST_P(olGetKernelTest, InvalidKernelName) {
ol_kernel_handle_t Kernel = nullptr;
ol_symbol_handle_t Kernel = nullptr;
ASSERT_ERROR(OL_ERRC_NOT_FOUND,
olGetKernel(Program, "invalid_kernel_name", &Kernel));
}

View File

@ -43,7 +43,7 @@ struct LaunchSingleKernelTestBase : LaunchKernelTestBase {
ASSERT_SUCCESS(olGetKernel(Program, kernel, &Kernel));
}
ol_kernel_handle_t Kernel = nullptr;
ol_symbol_handle_t Kernel = nullptr;
};
#define KERNEL_TEST(NAME, KERNEL) \
@ -70,7 +70,7 @@ struct LaunchMultipleKernelTestBase : LaunchKernelTestBase {
ASSERT_SUCCESS(olGetKernel(Program, K, &Kernels[I++]));
}
std::vector<ol_kernel_handle_t> Kernels;
std::vector<ol_symbol_handle_t> Kernels;
};
#define KERNEL_MULTI_TEST(NAME, PROGRAM, ...) \