diff --git a/offload/liboffload/API/Common.td b/offload/liboffload/API/Common.td index a621de081a0c..6eaf604c8ebb 100644 --- a/offload/liboffload/API/Common.td +++ b/offload/liboffload/API/Common.td @@ -74,10 +74,9 @@ def : Handle { let desc = "Handle of program object"; } -def : Typedef { - let name = "ol_kernel_handle_t"; - let desc = "Handle of kernel object"; - let value = "void *"; +def : Handle { + let name = "ol_symbol_handle_t"; + let desc = "Handle of an object in a device's memory for a specific program"; } def ErrorCode : Enum { @@ -112,6 +111,7 @@ def ErrorCode : Enum { Etor<"INVALID_DEVICE", "invalid device">, Etor<"INVALID_QUEUE", "invalid queue">, Etor<"INVALID_EVENT", "invalid event">, + Etor<"SYMBOL_KIND", "the operation does not support this symbol kind">, ]; } diff --git a/offload/liboffload/API/Kernel.td b/offload/liboffload/API/Kernel.td index 0913a036fa04..7cb3016afd59 100644 --- a/offload/liboffload/API/Kernel.td +++ b/offload/liboffload/API/Kernel.td @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// // -// This file contains Offload API definitions related to the kernel handle +// This file contains Offload API definitions related to loading and launching +// kernels // //===----------------------------------------------------------------------===// @@ -14,12 +15,12 @@ def : Function { let name = "olGetKernel"; let desc = "Get a kernel from the function identified by `KernelName` in the given program."; let details = [ - "The kernel handle returned is owned by the device so does not need to be destroyed." + "Symbol handles are owned by the program and do not need to be manually destroyed." ]; let params = [ Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>, Param<"const char*", "KernelName", "name of the kernel entry point in the program", PARAM_IN>, - Param<"ol_kernel_handle_t*", "Kernel", "output pointer for the fetched kernel", PARAM_OUT> + Param<"ol_symbol_handle_t*", "Kernel", "output pointer for the fetched kernel", PARAM_OUT> ]; let returns = []; } @@ -45,7 +46,7 @@ def : Function { let params = [ Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN_OPTIONAL>, Param<"ol_device_handle_t", "Device", "handle of the device to execute on", PARAM_IN>, - Param<"ol_kernel_handle_t", "Kernel", "handle of the kernel", PARAM_IN>, + Param<"ol_symbol_handle_t", "Kernel", "handle of the kernel", PARAM_IN>, Param<"const void*", "ArgumentsData", "pointer to the kernel argument struct", PARAM_IN_OPTIONAL>, Param<"size_t", "ArgumentsSize", "size of the kernel argument struct", PARAM_IN>, Param<"const ol_kernel_launch_size_args_t*", "LaunchSizeArgs", "pointer to the struct containing launch size parameters", PARAM_IN>, @@ -55,5 +56,6 @@ def : Function { Return<"OL_ERRC_INVALID_ARGUMENT", ["`Queue == NULL && EventOut != NULL`"]>, Return<"OL_ERRC_INVALID_ARGUMENT", ["`ArgumentsSize > 0 && ArgumentsData == NULL`"]>, Return<"OL_ERRC_INVALID_DEVICE", ["If Queue is non-null but does not belong to Device"]>, + Return<"OL_ERRC_SYMBOL_KIND", ["The provided symbol is not a kernel"]>, ]; } diff --git a/offload/liboffload/API/OffloadAPI.td b/offload/liboffload/API/OffloadAPI.td index f9829155b6ce..1b78edf4c774 100644 --- a/offload/liboffload/API/OffloadAPI.td +++ b/offload/liboffload/API/OffloadAPI.td @@ -18,3 +18,4 @@ include "Queue.td" include "Event.td" include "Program.td" include "Kernel.td" +include "Symbol.td" diff --git a/offload/liboffload/API/Symbol.td b/offload/liboffload/API/Symbol.td new file mode 100644 index 000000000000..cf4d45b09f03 --- /dev/null +++ b/offload/liboffload/API/Symbol.td @@ -0,0 +1,19 @@ +//===-- Symbol.td - Symbol definitions for Offload ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains Offload API definitions related to the symbol handle. +// +//===----------------------------------------------------------------------===// + +def : Enum { + let name = "ol_symbol_kind_t"; + let desc = "The kind of a symbol"; + let etors =[ + Etor<"KERNEL", "a kernel object">, + ]; +} diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 6ed163d165a4..fa5d18c04404 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -84,9 +84,17 @@ struct ol_program_impl_t { DeviceImage(DeviceImage) {} plugin::DeviceImageTy *Image; std::unique_ptr ImageData; + std::vector> Symbols; __tgt_device_image DeviceImage; }; +struct ol_symbol_impl_t { + ol_symbol_impl_t(GenericKernelTy *Kernel) + : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {} + std::variant PluginImpl; + ol_symbol_kind_t Kind; +}; + namespace llvm { namespace offload { @@ -653,7 +661,7 @@ Error olDestroyProgram_impl(ol_program_handle_t Program) { } Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName, - ol_kernel_handle_t *Kernel) { + ol_symbol_handle_t *Kernel) { auto &Device = Program->Image->getDevice(); auto KernelImpl = Device.constructKernel(KernelName); @@ -663,13 +671,15 @@ Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName, if (auto Err = KernelImpl->init(Device, *Program->Image)) return Err; - *Kernel = &*KernelImpl; + *Kernel = Program->Symbols + .emplace_back(std::make_unique(&*KernelImpl)) + .get(); return Error::success(); } Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device, - ol_kernel_handle_t Kernel, const void *ArgumentsData, + ol_symbol_handle_t Kernel, const void *ArgumentsData, size_t ArgumentsSize, const ol_kernel_launch_size_args_t *LaunchSizeArgs, ol_event_handle_t *EventOut) { @@ -680,6 +690,10 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device, "device specified does not match the device of the given queue"); } + if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL) + return createOffloadError(ErrorCode::SYMBOL_KIND, + "provided symbol is not a kernel"); + auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr; AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl); KernelArgsTy LaunchArgs{}; @@ -698,7 +712,7 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device, // Don't do anything with pointer indirection; use arg data as-is LaunchArgs.Flags.IsCUDA = true; - auto *KernelImpl = reinterpret_cast(Kernel); + auto *KernelImpl = std::get(Kernel->PluginImpl); auto Err = KernelImpl->launch(*DeviceImpl, LaunchArgs.ArgPtrs, nullptr, LaunchArgs, AsyncInfoWrapper); diff --git a/offload/unittests/OffloadAPI/common/Fixtures.hpp b/offload/unittests/OffloadAPI/common/Fixtures.hpp index 5a7f8cd51be1..e443d9761f30 100644 --- a/offload/unittests/OffloadAPI/common/Fixtures.hpp +++ b/offload/unittests/OffloadAPI/common/Fixtures.hpp @@ -120,7 +120,7 @@ struct OffloadKernelTest : OffloadProgramTest { RETURN_ON_FATAL_FAILURE(OffloadProgramTest::TearDown()); } - ol_kernel_handle_t Kernel = nullptr; + ol_symbol_handle_t Kernel = nullptr; }; struct OffloadQueueTest : OffloadDeviceTest { diff --git a/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp b/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp index 83755e554ebe..34870f1fbf0a 100644 --- a/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp +++ b/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp @@ -14,13 +14,13 @@ using olGetKernelTest = OffloadProgramTest; OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetKernelTest); TEST_P(olGetKernelTest, Success) { - ol_kernel_handle_t Kernel = nullptr; + ol_symbol_handle_t Kernel = nullptr; ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel)); ASSERT_NE(Kernel, nullptr); } TEST_P(olGetKernelTest, InvalidNullProgram) { - ol_kernel_handle_t Kernel = nullptr; + ol_symbol_handle_t Kernel = nullptr; ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olGetKernel(nullptr, "foo", &Kernel)); } @@ -32,7 +32,7 @@ TEST_P(olGetKernelTest, InvalidNullKernelPointer) { // Error code returning from plugin interface not yet supported TEST_P(olGetKernelTest, InvalidKernelName) { - ol_kernel_handle_t Kernel = nullptr; + ol_symbol_handle_t Kernel = nullptr; ASSERT_ERROR(OL_ERRC_NOT_FOUND, olGetKernel(Program, "invalid_kernel_name", &Kernel)); } diff --git a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp index 41d5c79c42de..acda4795edec 100644 --- a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp +++ b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp @@ -43,7 +43,7 @@ struct LaunchSingleKernelTestBase : LaunchKernelTestBase { ASSERT_SUCCESS(olGetKernel(Program, kernel, &Kernel)); } - ol_kernel_handle_t Kernel = nullptr; + ol_symbol_handle_t Kernel = nullptr; }; #define KERNEL_TEST(NAME, KERNEL) \ @@ -70,7 +70,7 @@ struct LaunchMultipleKernelTestBase : LaunchKernelTestBase { ASSERT_SUCCESS(olGetKernel(Program, K, &Kernels[I++])); } - std::vector Kernels; + std::vector Kernels; }; #define KERNEL_MULTI_TEST(NAME, PROGRAM, ...) \