[Offload] olLaunchHostFunction (#152482)

Add an `olLaunchHostFunction` method that allows enqueueing host work
to the stream.
This commit is contained in:
Ross Brunton 2025-08-15 09:39:48 +01:00 committed by GitHub
parent 598562077a
commit 30c7951136
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 233 additions and 2 deletions

View File

@ -31,6 +31,13 @@ class IsHandleType<string Type> {
!ne(!find(Type, "_handle_t", !sub(!size(Type), 9)), -1));
}
// Does the type end with '_cb_t'?
// `ret` is 1 when `Type` has the suffix "_cb_t" and 0 otherwise.
class IsCallbackType<string Type> {
// size("_cb_t") == 5
// A type shorter than the suffix can never match. Otherwise the search for
// "_cb_t" starts at offset size(Type) - 5, so only the last five characters
// are considered; !find yields -1 when there is no match.
bit ret = !if(!lt(!size(Type), 5), 0,
!ne(!find(Type, "_cb_t", !sub(!size(Type), 5)), -1));
}
// Does the type end with '*'?
class IsPointerType<string Type> {
bit ret = !ne(!find(Type, "*", !sub(!size(Type), 1)), -1);
@ -58,6 +65,7 @@ class Param<string Type, string Name, string Desc, bits<3> Flags = 0> {
TypeInfo type_info = TypeInfo<"", "">;
bit IsHandle = IsHandleType<type>.ret;
bit IsPointer = IsPointerType<type>.ret;
bit IsCallback = IsCallbackType<type>.ret;
}
// A parameter whose range is described by other parameters in the function.
@ -81,7 +89,7 @@ class ShouldCheckHandle<Param P> {
}
// `ret` is set when the PARAM_OPTIONAL bit is clear in P.flags and the
// parameter's type qualifies (see the expressions below). Presumably this
// drives the generated null-pointer validation - confirm in the emitter.
class ShouldCheckPointer<Param P> {
// NOTE(review): the two `bit ret` lines below look like the removed/added
// pair of a diff hunk. The first accepts only P.IsPointer; the second also
// accepts P.IsCallback. Confirm that only the second is present in the file.
bit ret = !and(P.IsPointer, !eq(!and(PARAM_OPTIONAL, P.flags), 0));
bit ret = !and(!or(P.IsPointer, P.IsCallback), !eq(!and(PARAM_OPTIONAL, P.flags), 0));
}
// For a list of returns that contains a specific return code, find and append

View File

@ -108,3 +108,29 @@ def : Function {
Return<"OL_ERRC_INVALID_QUEUE">
];
}
// Function-pointer typedef `ol_host_function_cb_t`: a callback that takes a
// single `void *UserData` argument and returns nothing.
def : FptrTypedef {
let name = "ol_host_function_cb_t";
let desc = "Host function for use by `olLaunchHostFunction`.";
// The only parameter is the opaque user data pointer.
let params = [
Param<"void *", "UserData", "user specified data passed into `olLaunchHostFunction`.", PARAM_IN>,
];
// The callback has no return value.
let return = "void";
}
// API entry point `olLaunchHostFunction`: takes a queue handle, a callback of
// type `ol_host_function_cb_t` and an optional user data pointer.
def : Function {
let name = "olLaunchHostFunction";
let desc = "Enqueue a callback function on the host.";
// Ordering and usage constraints that apply to the callback.
let details = [
"The provided function will be called from the same process as the one that called `olLaunchHostFunction`.",
"The callback will not run until all previous work submitted to the queue has completed.",
"The callback must return before any work submitted to the queue after it is started.",
"The callback must not call any liboffload API functions or any backend specific functions (such as Cuda or HSA library functions).",
];
// Queue and Callback are plain PARAM_IN; UserData is PARAM_IN_OPTIONAL.
let params = [
Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN>,
Param<"ol_host_function_cb_t", "Callback", "the callback function to call on the host", PARAM_IN>,
Param<"void *", "UserData", "a pointer that will be passed verbatim to the callback function", PARAM_IN_OPTIONAL>,
];
// No function-specific return codes are listed here.
let returns = [];
}

View File

@ -833,5 +833,12 @@ Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
return olGetSymbolInfoImplDetail(Symbol, PropName, 0, nullptr, PropSizeRet);
}
/// Implementation of `olLaunchHostFunction`: hands \p Callback and
/// \p UserData to the plugin device behind \p Queue, enqueueing the call on
/// the queue's async info.
Error olLaunchHostFunction_impl(ol_queue_handle_t Queue,
                                ol_host_function_cb_t Callback,
                                void *UserData) {
  auto &OwningDevice = *Queue->Device;
  return OwningDevice.Device->enqueueHostCall(Callback, UserData,
                                              Queue->AsyncInfo);
}
} // namespace offload
} // namespace llvm

View File

@ -1063,6 +1063,20 @@ private:
/// Indicate to spread data transfers across all available SDMAs
bool UseMultipleSdmaEngines;
/// Entry point of the helper thread that runs a host callback. Blocks on
/// \p InputSignal when one is provided, invokes \p Callback with \p UserData,
/// and finally signals \p OutputSignal.
static void CallbackWrapper(AMDGPUSignalTy *InputSignal,
                            AMDGPUSignalTy *OutputSignal,
                            void (*Callback)(void *), void *UserData) {
  if (InputSignal != nullptr) {
    // The wait is not expected to error in this context; if it does, treat
    // it as a fatal internal error.
    auto WaitErr = InputSignal->wait();
    if (WaitErr)
      reportFatalInternalError(std::move(WaitErr));
  }

  Callback(UserData);
  OutputSignal->signal();
}
/// Return the current number of asynchronous operations on the stream.
uint32_t size() const { return NextSlot; }
@ -1495,6 +1509,31 @@ public:
OutputSignal->get());
}
/// Schedule \p Callback (invoked with \p UserData) as an operation on this
/// stream. A detached helper thread waits on the dependency returned by
/// consume(), if any, runs the callback and then signals the operation's
/// output signal.
Error pushHostCallback(void (*Callback)(void *), void *UserData) {
  // Obtain the signal that will mark completion of the callback.
  AMDGPUSignalTy *DoneSignal = nullptr;
  if (auto AcquireErr = SignalManager.getResource(DoneSignal))
    return AcquireErr;
  DoneSignal->reset();
  DoneSignal->increaseUseCount();

  // Take a stream slot while holding the lock; consume() also computes the
  // signal this operation has to wait on.
  AMDGPUSignalTy *const WaitSignal = [&] {
    std::lock_guard<std::mutex> Guard(Mutex);
    return consume(DoneSignal).second;
  }();

  // The thread is intentionally "leaked": this is consistent with other work
  // added to the queue. Both signals remain valid until the output signal
  // has been signaled.
  std::thread Worker(CallbackWrapper, WaitSignal, DoneSignal, Callback,
                     UserData);
  Worker.detach();

  return Plugin::success();
}
/// Synchronize with the stream. The current thread waits until all operations
/// are finalized and it performs the pending post actions (i.e., releasing
/// intermediate buffers).
@ -2553,6 +2592,15 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
return Plugin::success();
}
/// Queue \p Callback (called with \p UserData) on the AMDGPU stream bound to
/// \p AsyncInfo.
Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
                          AsyncInfoWrapperTy &AsyncInfo) override {
  AMDGPUStreamTy *TargetStream = nullptr;
  auto StreamErr = getStream(AsyncInfo, TargetStream);
  if (StreamErr)
    return StreamErr;

  return TargetStream->pushHostCallback(Callback, UserData);
}
/// Create an event.
Error createEventImpl(void **EventPtrStorage) override {
AMDGPUEventTy **Event = reinterpret_cast<AMDGPUEventTy **>(EventPtrStorage);

View File

@ -965,6 +965,12 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
Error initDeviceInfo(__tgt_device_info *DeviceInfo);
virtual Error initDeviceInfoImpl(__tgt_device_info *DeviceInfo) = 0;
/// Enqueue a host call to AsyncInfo
/// \p Callback is the host function to enqueue and \p UserData the pointer
/// it will be called with.
Error enqueueHostCall(void (*Callback)(void *), void *UserData,
__tgt_async_info *AsyncInfo);
/// Plugin-specific part of enqueueHostCall; pure virtual, so every device
/// type has to provide an implementation.
virtual Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
AsyncInfoWrapperTy &AsyncInfo) = 0;
/// Create an event.
Error createEvent(void **EventPtrStorage);
virtual Error createEventImpl(void **EventPtrStorage) = 0;

View File

@ -1589,6 +1589,15 @@ Error GenericDeviceTy::initAsyncInfo(__tgt_async_info **AsyncInfoPtr) {
return Err;
}
/// Enqueue \p Callback (called with \p UserData) on \p AsyncInfo by
/// delegating to the plugin's enqueueHostCallImpl. The resulting error is run
/// through the async info wrapper's finalize() before it is returned.
Error GenericDeviceTy::enqueueHostCall(void (*Callback)(void *), void *UserData,
                                       __tgt_async_info *AsyncInfo) {
  AsyncInfoWrapperTy Wrapper(*this, AsyncInfo);

  auto Status = enqueueHostCallImpl(Callback, UserData, Wrapper);
  Wrapper.finalize(Status);
  return Status;
}
Error GenericDeviceTy::initDeviceInfo(__tgt_device_info *DeviceInfo) {
assert(DeviceInfo && "Invalid device info");

View File

@ -873,6 +873,19 @@ struct CUDADeviceTy : public GenericDeviceTy {
return Plugin::success();
}
/// Enqueue \p Callback, to be invoked with \p UserData, on the CUDA stream
/// associated with \p AsyncInfo using cuLaunchHostFunc.
///
/// \returns an error if the context cannot be set, the stream cannot be
/// obtained, or the driver rejects the launch.
Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
                          AsyncInfoWrapperTy &AsyncInfo) override {
  if (auto Err = setContext())
    return Err;

  CUstream Stream;
  if (auto Err = getStream(AsyncInfo, Stream))
    return Err;

  // Report failures under the name of the driver entry point actually
  // called, so the message can be matched against the CUDA documentation.
  CUresult Res = cuLaunchHostFunc(Stream, Callback, UserData);
  return Plugin::check(Res, "error in cuLaunchHostFunc: %s");
}
/// Create an event.
Error createEventImpl(void **EventPtrStorage) override {
CUevent *Event = reinterpret_cast<CUevent *>(EventPtrStorage);

View File

@ -320,6 +320,12 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
"initDeviceInfoImpl not supported");
}
/// Run the host call immediately: \p Callback is invoked with \p UserData
/// before this function returns. \p AsyncInfo is not used.
Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
                          AsyncInfoWrapperTy &AsyncInfo) override {
  Callback(UserData);

  return Plugin::success();
}
/// This plugin does not support the event API. Do nothing without failing.
Error createEventImpl(void **EventPtrStorage) override {
*EventPtrStorage = nullptr;

View File

@ -41,7 +41,8 @@ add_offload_unittest("queue"
queue/olDestroyQueue.cpp
queue/olGetQueueInfo.cpp
queue/olGetQueueInfoSize.cpp
queue/olWaitEvents.cpp)
queue/olWaitEvents.cpp
queue/olLaunchHostFunction.cpp)
add_offload_unittest("symbol"
symbol/olGetSymbol.cpp

View File

@ -0,0 +1,107 @@
//===------- Offload API tests - olLaunchHostFunction ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "../common/Fixtures.hpp"
#include <OffloadAPI.h>
#include <atomic>
#include <chrono>
#include <gtest/gtest.h>
#include <thread>
// Fixture for the queue-only tests; adds nothing on top of OffloadQueueTest.
struct olLaunchHostFunctionTest : OffloadQueueTest {};
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olLaunchHostFunctionTest);
// Fixture for the test that also launches a kernel; adds nothing on top of
// OffloadKernelTest.
struct olLaunchHostFunctionKernelTest : OffloadKernelTest {};
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olLaunchHostFunctionKernelTest);
TEST_P(olLaunchHostFunctionTest, Success) {
  // Enqueueing a callback that does nothing, with no user data, must succeed.
  const auto NoOp = [](void *) {};
  ASSERT_SUCCESS(olLaunchHostFunction(Queue, NoOp, nullptr));
}
TEST_P(olLaunchHostFunctionTest, SuccessSequence) {
  // The first two entries are seeded with 1. Every queued callback fills in
  // one further entry as the sum of its two predecessors, so the result is
  // only correct if the callbacks run in submission order.
  constexpr uint32_t NumElems = 16;
  uint32_t Buff[NumElems] = {1, 1};

  const auto FillFromPreviousTwo = [](void *Slot) {
    uint32_t *Elem = reinterpret_cast<uint32_t *>(Slot);
    Elem[0] = Elem[-1] + Elem[-2];
  };

  for (uint32_t Idx = 2; Idx < NumElems; ++Idx) {
    ASSERT_SUCCESS(olLaunchHostFunction(Queue, FillFromPreviousTwo, &Buff[Idx]));
  }

  ASSERT_SUCCESS(olSyncQueue(Queue));

  for (uint32_t Idx = 2; Idx < NumElems; ++Idx) {
    ASSERT_EQ(Buff[Idx], Buff[Idx - 1] + Buff[Idx - 2]);
  }
}
TEST_P(olLaunchHostFunctionKernelTest, SuccessBlocking) {
  // Verify that a host function can block execution of later work on the
  // queue: a host task is enqueued that only returns once Block is cleared,
  // and the kernel enqueued after it must not have run before that.
  ol_kernel_launch_size_args_t LaunchArgs;
  LaunchArgs.Dimensions = 1;
  LaunchArgs.GroupSize = {64, 1, 1};
  LaunchArgs.NumGroups = {1, 1, 1};
  LaunchArgs.DynSharedMemory = 0;

  ol_queue_handle_t Queue;
  ASSERT_SUCCESS(olCreateQueue(Device, &Queue));

  void *Mem;
  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
                            LaunchArgs.GroupSize.x * sizeof(uint32_t), &Mem));
  uint32_t *Data = static_cast<uint32_t *>(Mem);
  for (uint32_t i = 0; i < 64; i++) {
    Data[i] = 0;
  }

  // The flag is written by this thread and polled from the callback, which
  // may run on a different thread. It therefore has to be atomic: volatile
  // gives no inter-thread synchronization and would be a data race.
  std::atomic<bool> Block{true};
  ASSERT_SUCCESS(olLaunchHostFunction(
      Queue,
      [](void *Ptr) {
        auto *Flag = static_cast<std::atomic<bool> *>(Ptr);
        while (Flag->load())
          std::this_thread::yield();
      },
      &Block));

  struct {
    void *Mem;
  } Args{Mem};
  ASSERT_SUCCESS(
      olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args), &LaunchArgs));

  // Give the kernel a chance to run if the host function failed to block it;
  // the buffer must still hold its initial zeros.
  std::this_thread::sleep_for(std::chrono::milliseconds(500));
  for (uint32_t i = 0; i < 64; i++) {
    ASSERT_EQ(Data[i], 0u);
  }

  // Release the host function; the kernel may now run and is expected to
  // store each element's index.
  Block.store(false);
  ASSERT_SUCCESS(olSyncQueue(Queue));

  for (uint32_t i = 0; i < 64; i++) {
    ASSERT_EQ(Data[i], i);
  }

  ASSERT_SUCCESS(olDestroyQueue(Queue));
  ASSERT_SUCCESS(olMemFree(Mem));
}
TEST_P(olLaunchHostFunctionTest, InvalidNullCallback) {
  // A null callback must be rejected with OL_ERRC_INVALID_NULL_POINTER.
  ol_host_function_cb_t NoCallback = nullptr;
  ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
               olLaunchHostFunction(Queue, NoCallback, nullptr));
}
TEST_P(olLaunchHostFunctionTest, InvalidNullQueue) {
  // A null queue handle must be rejected with OL_ERRC_INVALID_NULL_HANDLE,
  // even when the callback itself is valid.
  ol_queue_handle_t NullQueue = nullptr;
  const auto NoOp = [](void *) {};
  ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
               olLaunchHostFunction(NullQueue, NoOp, nullptr));
}