[Offload] olLaunchHostFunction
(#152482)
Add an `olLaunchHostFunction` method that allows enqueueing host work to the stream.
This commit is contained in:
parent
598562077a
commit
30c7951136
@ -31,6 +31,13 @@ class IsHandleType<string Type> {
|
||||
!ne(!find(Type, "_handle_t", !sub(!size(Type), 9)), -1));
|
||||
}
|
||||
|
||||
// Does the type end with '_cb_t'?
// `ret` is 0 when Type is shorter than the suffix; otherwise it is 1 only if
// "_cb_t" is found when searching from position size(Type) - 5, i.e. at the
// very end of the string.
|
||||
class IsCallbackType<string Type> {
|
||||
// size("_cb_t") == 5
|
||||
bit ret = !if(!lt(!size(Type), 5), 0,
|
||||
!ne(!find(Type, "_cb_t", !sub(!size(Type), 5)), -1));
|
||||
}
|
||||
|
||||
// Does the type end with '*'?
|
||||
class IsPointerType<string Type> {
|
||||
bit ret = !ne(!find(Type, "*", !sub(!size(Type), 1)), -1);
|
||||
@ -58,6 +65,7 @@ class Param<string Type, string Name, string Desc, bits<3> Flags = 0> {
|
||||
TypeInfo type_info = TypeInfo<"", "">;
|
||||
bit IsHandle = IsHandleType<type>.ret;
|
||||
bit IsPointer = IsPointerType<type>.ret;
|
||||
bit IsCallback = IsCallbackType<type>.ret;
|
||||
}
|
||||
|
||||
// A parameter whose range is described by other parameters in the function.
|
||||
@ -81,7 +89,7 @@ class ShouldCheckHandle<Param P> {
|
||||
}
|
||||
|
||||
// Decides whether a parameter gets a pointer check: `ret` is set only when the
// PARAM_OPTIONAL bit is clear in P.flags.
// NOTE(review): two `bit ret` definitions appear below; this looks like the
// removed and added lines of a diff hunk shown together. Confirm that only the
// second form (which also covers P.IsCallback) exists in the real file.
class ShouldCheckPointer<Param P> {
|
||||
bit ret = !and(P.IsPointer, !eq(!and(PARAM_OPTIONAL, P.flags), 0));
|
||||
bit ret = !and(!or(P.IsPointer, P.IsCallback), !eq(!and(PARAM_OPTIONAL, P.flags), 0));
|
||||
}
|
||||
|
||||
// For a list of returns that contains a specific return code, find and append
|
||||
|
@ -108,3 +108,29 @@ def : Function {
|
||||
Return<"OL_ERRC_INVALID_QUEUE">
|
||||
];
|
||||
}
|
||||
|
||||
// Function-pointer typedef for the callback accepted by `olLaunchHostFunction`.
// The callback receives a single `void *` (UserData) and returns `void`. The
// name ends in `_cb_t`, which is the suffix IsCallbackType looks for.
def : FptrTypedef {
|
||||
let name = "ol_host_function_cb_t";
|
||||
let desc = "Host function for use by `olLaunchHostFunction`.";
|
||||
let params = [
|
||||
Param<"void *", "UserData", "user specified data passed into `olLaunchHostFunction`.", PARAM_IN>,
|
||||
];
|
||||
let return = "void";
|
||||
}
|
||||
|
||||
// API entry point that enqueues a host callback on a queue.
// `Queue` and `Callback` are declared PARAM_IN (not optional); `UserData` is
// PARAM_IN_OPTIONAL, so it may be null.
def : Function {
|
||||
let name = "olLaunchHostFunction";
|
||||
let desc = "Enqueue a callback function on the host.";
|
||||
let details = [
|
||||
"The provided function will be called from the same process as the one that called `olLaunchHostFunction`.",
|
||||
"The callback will not run until all previous work submitted to the queue has completed.",
|
||||
"The callback must return before any work submitted to the queue after it is started.",
|
||||
"The callback must not call any liboffload API functions or any backend specific functions (such as Cuda or HSA library functions).",
|
||||
];
|
||||
let params = [
|
||||
Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN>,
|
||||
Param<"ol_host_function_cb_t", "Callback", "the callback function to call on the host", PARAM_IN>,
|
||||
Param<"void *", "UserData", "a pointer that will be passed verbatim to the callback function", PARAM_IN_OPTIONAL>,
|
||||
];
|
||||
// No function-specific return codes are listed for this entry point.
|
||||
let returns = [];
|
||||
}
|
||||
|
@ -833,5 +833,12 @@ Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
|
||||
return olGetSymbolInfoImplDetail(Symbol, PropName, 0, nullptr, PropSizeRet);
|
||||
}
|
||||
|
||||
/// Implementation of olLaunchHostFunction.
///
/// Forwards \p Callback and \p UserData to the device reached through
/// `Queue->Device->Device`, enqueueing the call on the queue's AsyncInfo.
/// Returns whatever error enqueueHostCall reports.
Error olLaunchHostFunction_impl(ol_queue_handle_t Queue,
                                ol_host_function_cb_t Callback,
                                void *UserData) {
  auto &&TargetDevice = Queue->Device->Device;
  return TargetDevice->enqueueHostCall(Callback, UserData, Queue->AsyncInfo);
}
|
||||
|
||||
} // namespace offload
|
||||
} // namespace llvm
|
||||
|
@ -1063,6 +1063,20 @@ private:
|
||||
/// Indicate to spread data transfers across all available SDMAs
|
||||
bool UseMultipleSdmaEngines;
|
||||
|
||||
/// Wrapper function for implementing host callbacks
|
||||
static void CallbackWrapper(AMDGPUSignalTy *InputSignal,
|
||||
AMDGPUSignalTy *OutputSignal,
|
||||
void (*Callback)(void *), void *UserData) {
|
||||
// The wait call will not error in this context.
|
||||
if (InputSignal)
|
||||
if (auto Err = InputSignal->wait())
|
||||
reportFatalInternalError(std::move(Err));
|
||||
|
||||
Callback(UserData);
|
||||
|
||||
OutputSignal->signal();
|
||||
}
|
||||
|
||||
/// Return the current number of asynchronous operations on the stream.
/// This is the value of NextSlot; no lock is taken inside this accessor.
|
||||
uint32_t size() const { return NextSlot; }
|
||||
|
||||
@ -1495,6 +1509,31 @@ public:
|
||||
OutputSignal->get());
|
||||
}
|
||||
|
||||
/// Enqueue a host callback on this stream.
///
/// A signal is taken from SignalManager to mark completion of the callback,
/// a stream slot is consumed under the stream mutex, and a detached thread is
/// started that runs CallbackWrapper with the resulting input/output signals.
/// Returns the error from SignalManager.getResource on failure, success
/// otherwise.
Error pushHostCallback(void (*Callback)(void *), void *UserData) {
  // Retrieve an available signal for the operation's output and prepare it.
  AMDGPUSignalTy *DoneSignal = nullptr;
  if (auto Err = SignalManager.getResource(DoneSignal))
    return Err;
  DoneSignal->reset();
  DoneSignal->increaseUseCount();

  // Consume a stream slot and compute dependencies while holding the lock;
  // the second element of the result is the signal the callback waits on.
  AMDGPUSignalTy *WaitSignal = [&]() -> AMDGPUSignalTy * {
    std::lock_guard<std::mutex> Lock(Mutex);
    return consume(DoneSignal).second;
  }();

  // "Leaking" the thread here is consistent with other work added to the
  // queue. The input and output signals will remain valid until the output is
  // signaled.
  std::thread Worker(CallbackWrapper, WaitSignal, DoneSignal, Callback,
                     UserData);
  Worker.detach();

  return Plugin::success();
}
|
||||
|
||||
/// Synchronize with the stream. The current thread waits until all operations
|
||||
/// are finalized and it performs the pending post actions (i.e., releasing
|
||||
/// intermediate buffers).
|
||||
@ -2553,6 +2592,15 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
|
||||
return Plugin::success();
|
||||
}
|
||||
|
||||
/// Enqueue a host callback on the AMDGPU stream associated with \p AsyncInfo.
///
/// \param Callback  function to invoke on the host.
/// \param UserData  pointer handed to \p Callback unchanged.
/// \param AsyncInfo wrapper used to look up the stream.
/// \returns the error from getStream if the stream cannot be obtained,
///          otherwise the result of AMDGPUStreamTy::pushHostCallback.
Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
                          AsyncInfoWrapperTy &AsyncInfo) override {
  AMDGPUStreamTy *Stream = nullptr;
  if (auto Err = getStream(AsyncInfo, Stream))
    return Err;

  return Stream->pushHostCallback(Callback, UserData);
}
|
||||
|
||||
/// Create an event.
|
||||
Error createEventImpl(void **EventPtrStorage) override {
|
||||
AMDGPUEventTy **Event = reinterpret_cast<AMDGPUEventTy **>(EventPtrStorage);
|
||||
|
@ -965,6 +965,12 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
|
||||
Error initDeviceInfo(__tgt_device_info *DeviceInfo);
|
||||
virtual Error initDeviceInfoImpl(__tgt_device_info *DeviceInfo) = 0;
|
||||
|
||||
/// Enqueue a host call to AsyncInfo. \p Callback is the function to run on
/// the host and \p UserData is the pointer passed to it. Each plugin device
/// supplies the actual behaviour by overriding the pure virtual
/// enqueueHostCallImpl declared below.
|
||||
Error enqueueHostCall(void (*Callback)(void *), void *UserData,
|
||||
__tgt_async_info *AsyncInfo);
|
||||
virtual Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
|
||||
AsyncInfoWrapperTy &AsyncInfo) = 0;
|
||||
|
||||
/// Create an event.
|
||||
Error createEvent(void **EventPtrStorage);
|
||||
virtual Error createEventImpl(void **EventPtrStorage) = 0;
|
||||
|
@ -1589,6 +1589,15 @@ Error GenericDeviceTy::initAsyncInfo(__tgt_async_info **AsyncInfoPtr) {
|
||||
return Err;
|
||||
}
|
||||
|
||||
/// Enqueue a host callback on \p AsyncInfo.
///
/// Wraps \p AsyncInfo in an AsyncInfoWrapperTy, delegates to the plugin's
/// enqueueHostCallImpl, and passes the resulting error through the wrapper's
/// finalize step before returning it.
Error GenericDeviceTy::enqueueHostCall(void (*Callback)(void *), void *UserData,
                                       __tgt_async_info *AsyncInfo) {
  AsyncInfoWrapperTy Wrapper(*this, AsyncInfo);

  Error Result = enqueueHostCallImpl(Callback, UserData, Wrapper);
  // finalize receives the error by name so it can see (and possibly update)
  // the outcome of the plugin call.
  Wrapper.finalize(Result);
  return Result;
}
|
||||
|
||||
Error GenericDeviceTy::initDeviceInfo(__tgt_device_info *DeviceInfo) {
|
||||
assert(DeviceInfo && "Invalid device info");
|
||||
|
||||
|
@ -873,6 +873,19 @@ struct CUDADeviceTy : public GenericDeviceTy {
|
||||
return Plugin::success();
|
||||
}
|
||||
|
||||
/// Enqueue a host callback on the CUDA stream associated with \p AsyncInfo.
///
/// \param Callback  function to invoke on the host.
/// \param UserData  pointer handed to \p Callback unchanged.
/// \param AsyncInfo wrapper used to look up the stream.
/// \returns the error from setContext or getStream if either fails,
///          otherwise the checked result of cuLaunchHostFunc.
Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
                          AsyncInfoWrapperTy &AsyncInfo) override {
  if (auto Err = setContext())
    return Err;

  CUstream Stream;
  if (auto Err = getStream(AsyncInfo, Stream))
    return Err;

  CUresult Res = cuLaunchHostFunc(Stream, Callback, UserData);
  // Name the driver entry point that was actually called so the diagnostic
  // can be matched against the CUDA documentation.
  return Plugin::check(Res, "error in cuLaunchHostFunc: %s");
}
|
||||
|
||||
/// Create an event.
|
||||
Error createEventImpl(void **EventPtrStorage) override {
|
||||
CUevent *Event = reinterpret_cast<CUevent *>(EventPtrStorage);
|
||||
|
@ -320,6 +320,12 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
|
||||
"initDeviceInfoImpl not supported");
|
||||
}
|
||||
|
||||
/// Run a host callback for this device.
///
/// The callback is invoked immediately on the calling thread with
/// \p UserData; \p AsyncInfo is not consulted. Always returns success.
Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
                          AsyncInfoWrapperTy &AsyncInfo) override {
  Callback(UserData);
  return Plugin::success();
}
|
||||
|
||||
/// This plugin does not support the event API. Do nothing without failing.
|
||||
Error createEventImpl(void **EventPtrStorage) override {
|
||||
*EventPtrStorage = nullptr;
|
||||
|
@ -41,7 +41,8 @@ add_offload_unittest("queue"
|
||||
queue/olDestroyQueue.cpp
|
||||
queue/olGetQueueInfo.cpp
|
||||
queue/olGetQueueInfoSize.cpp
|
||||
queue/olWaitEvents.cpp)
|
||||
queue/olWaitEvents.cpp
|
||||
queue/olLaunchHostFunction.cpp)
|
||||
|
||||
add_offload_unittest("symbol"
|
||||
symbol/olGetSymbol.cpp
|
||||
|
107
offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
Normal file
107
offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
Normal file
@ -0,0 +1,107 @@
|
||||
//===------- Offload API tests - olLaunchHostFunction ---------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "../common/Fixtures.hpp"
#include <OffloadAPI.h>
#include <gtest/gtest.h>
#include <atomic>
#include <chrono>
#include <thread>
|
||||
|
||||
// Fixture for the tests below that only need the `Queue` inherited from
// OffloadQueueTest.
struct olLaunchHostFunctionTest : OffloadQueueTest {};
|
||||
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olLaunchHostFunctionTest);
|
||||
|
||||
// Fixture for the test that also launches a kernel; `Device` and `Kernel`
// come from OffloadKernelTest.
struct olLaunchHostFunctionKernelTest : OffloadKernelTest {};
|
||||
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olLaunchHostFunctionKernelTest);
|
||||
|
||||
// Launching a no-op callback with null user data must succeed.
TEST_P(olLaunchHostFunctionTest, Success) {
  void (*NoOp)(void *) = [](void *) {};
  ASSERT_SUCCESS(olLaunchHostFunction(Queue, NoOp, nullptr));
}
|
||||
|
||||
// Enqueues one callback per slot 2..15; each callback writes the sum of the
// two preceding slots into its own slot. Because every callback reads values
// written by the ones enqueued before it, the final check only holds if the
// callbacks ran in submission order.
TEST_P(olLaunchHostFunctionTest, SuccessSequence) {
  uint32_t Buff[16] = {1, 1};

  void (*SumPreviousTwo)(void *) = [](void *Slot) {
    uint32_t *Cur = reinterpret_cast<uint32_t *>(Slot);
    *Cur = *(Cur - 1) + *(Cur - 2);
  };

  for (uint32_t Idx = 2; Idx < 16; ++Idx) {
    ASSERT_SUCCESS(olLaunchHostFunction(Queue, SumPreviousTwo, &Buff[Idx]));
  }

  ASSERT_SUCCESS(olSyncQueue(Queue));

  for (uint32_t Idx = 2; Idx != 16; ++Idx) {
    ASSERT_EQ(Buff[Idx], Buff[Idx - 1] + Buff[Idx - 2]);
  }
}
|
||||
|
||||
// Verify that a host function can block execution of the queue: a host task
// is enqueued that only returns once Block is cleared, followed by a kernel
// launch on the same queue. The kernel's output buffer must stay untouched
// while the host task is blocked and must be filled in after it is released.
TEST_P(olLaunchHostFunctionKernelTest, SuccessBlocking) {
  ol_kernel_launch_size_args_t LaunchArgs;
  LaunchArgs.Dimensions = 1;
  LaunchArgs.GroupSize = {64, 1, 1};
  LaunchArgs.NumGroups = {1, 1, 1};
  LaunchArgs.DynSharedMemory = 0;

  ol_queue_handle_t Queue;
  ASSERT_SUCCESS(olCreateQueue(Device, &Queue));

  void *Mem;
  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
                            LaunchArgs.GroupSize.x * sizeof(uint32_t), &Mem));

  uint32_t *Data = static_cast<uint32_t *>(Mem);
  for (uint32_t i = 0; i < 64; i++) {
    Data[i] = 0;
  }

  // The flag is written by this thread and read by whichever thread runs the
  // callback, so it must be an atomic; `volatile` gives no inter-thread
  // ordering or race-freedom guarantees.
  std::atomic<bool> Block{true};
  void (*WaitUntilReleased)(void *) = [](void *Ptr) {
    auto *Flag = static_cast<std::atomic<bool> *>(Ptr);
    while (Flag->load(std::memory_order_acquire))
      std::this_thread::yield();
  };
  ASSERT_SUCCESS(olLaunchHostFunction(Queue, WaitUntilReleased, &Block));

  struct {
    void *Mem;
  } Args{Mem};
  ASSERT_SUCCESS(
      olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args), &LaunchArgs));

  // Give the kernel a chance to (incorrectly) run while the host task is
  // still blocked. The result is only tallied here and asserted after the
  // callback has been released: asserting now would return from the test
  // while the callback is still spinning on the stack-allocated flag.
  std::this_thread::sleep_for(std::chrono::milliseconds(500));
  uint32_t EarlyWrites = 0;
  for (uint32_t i = 0; i < 64; i++) {
    if (Data[i] != 0u)
      ++EarlyWrites;
  }

  Block.store(false, std::memory_order_release);
  ASSERT_SUCCESS(olSyncQueue(Queue));

  ASSERT_EQ(EarlyWrites, 0u);
  for (uint32_t i = 0; i < 64; i++) {
    ASSERT_EQ(Data[i], i);
  }

  ASSERT_SUCCESS(olDestroyQueue(Queue));
  ASSERT_SUCCESS(olMemFree(Mem));
}
|
||||
|
||||
// A null callback must be rejected with OL_ERRC_INVALID_NULL_POINTER.
TEST_P(olLaunchHostFunctionTest, InvalidNullCallback) {
  void *NoUserData = nullptr;
  ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
               olLaunchHostFunction(Queue, nullptr, NoUserData));
}
|
||||
|
||||
// A null queue handle must be rejected with OL_ERRC_INVALID_NULL_HANDLE.
TEST_P(olLaunchHostFunctionTest, InvalidNullQueue) {
  void (*NoOp)(void *) = [](void *) {};
  ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
               olLaunchHostFunction(nullptr, NoOp, nullptr));
}
|
Loading…
x
Reference in New Issue
Block a user