diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp index 7114dad020e3..6fc75ac15428 100644 --- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp @@ -621,9 +621,9 @@ struct AMDGPUSignalTy { } /// Wait until the signal gets a zero value. - Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr, + Error wait(const uint64_t ActiveTimeout = 0, GenericDeviceTy *Device = nullptr) const { - if (ActiveTimeout && !RPCServer) { + if (ActiveTimeout) { hsa_signal_value_t Got = 1; Got = hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0, ActiveTimeout, HSA_WAIT_STATE_ACTIVE); @@ -632,14 +632,11 @@ struct AMDGPUSignalTy { } // If there is an RPC device attached to this stream we run it as a server. - uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX; - auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED; + uint64_t Timeout = UINT64_MAX; + auto WaitState = HSA_WAIT_STATE_BLOCKED; while (hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0, - Timeout, WaitState) != 0) { - if (RPCServer && Device) - if (auto Err = RPCServer->runServer(*Device)) - return Err; - } + Timeout, WaitState) != 0) + ; return Plugin::success(); } @@ -1052,11 +1049,6 @@ private: /// operation that was already finalized in a previous stream sycnhronize. uint32_t SyncCycle; - /// A pointer associated with an RPC server running on the given device. If - /// RPC is not being used this will be a null pointer. Otherwise, this - /// indicates that an RPC server is expected to be run on this stream. - RPCServerTy *RPCServer; - /// Mutex to protect stream's management. mutable std::mutex Mutex; @@ -1236,9 +1228,6 @@ public: /// Deinitialize the stream's signals. Error deinit() { return Plugin::success(); } - /// Attach an RPC server to this stream. 
- void setRPCServer(RPCServerTy *Server) { RPCServer = Server; } - /// Push a asynchronous kernel to the stream. The kernel arguments must be /// placed in a special allocation for kernel args and must keep alive until /// the kernel finalizes. Once the kernel is finished, the stream will release @@ -1266,10 +1255,30 @@ public: if (auto Err = Slots[Curr].schedReleaseBuffer(KernelArgs, MemoryManager)) return Err; + // If we are running an RPC server we want to wake up the server thread + // whenever there is a kernel running and let it sleep otherwise. + if (Device.getRPCServer()) + Device.Plugin.getRPCServer().Thread->notify(); + // Push the kernel with the output signal and an input signal (optional) - return Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads, NumBlocks, - GroupSize, StackSize, OutputSignal, - InputSignal); + if (auto Err = Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads, + NumBlocks, GroupSize, StackSize, + OutputSignal, InputSignal)) + return Err; + + // Register a callback to indicate when the kernel is complete. + if (Device.getRPCServer()) { + if (auto Err = Slots[Curr].schedCallback( + [](void *Data) -> llvm::Error { + GenericPluginTy &Plugin = + *reinterpret_cast(Data); + Plugin.getRPCServer().Thread->finish(); + return Error::success(); + }, + &Device.Plugin)) + return Err; + } + return Plugin::success(); } /// Push an asynchronous memory copy between pinned memory buffers. @@ -1479,8 +1488,8 @@ public: return Plugin::success(); // Wait until all previous operations on the stream have completed. - if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, - RPCServer, &Device)) + if (auto Err = + Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, &Device)) return Err; // Reset the stream and perform all pending post actions. 
@@ -3027,7 +3036,7 @@ AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device) : Agent(Device.getAgent()), Queue(nullptr), SignalManager(Device.getSignalManager()), Device(Device), // Initialize the std::deque with some empty positions. - Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr), + Slots(32), NextSlot(0), SyncCycle(0), StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()), UseMultipleSdmaEngines(Device.useMultipleSdmaEngines()) {} @@ -3383,10 +3392,6 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice, if (auto Err = AMDGPUDevice.getStream(AsyncInfoWrapper, Stream)) return Err; - // If this kernel requires an RPC server we attach its pointer to the stream. - if (GenericDevice.getRPCServer()) - Stream->setRPCServer(GenericDevice.getRPCServer()); - // Only COV5 implicitargs needs to be set. COV4 implicitargs are not used. if (ImplArgs && getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) { diff --git a/offload/plugins-nextgen/common/include/RPC.h b/offload/plugins-nextgen/common/include/RPC.h index 5b9b7ffd086b..f3a8e7555020 100644 --- a/offload/plugins-nextgen/common/include/RPC.h +++ b/offload/plugins-nextgen/common/include/RPC.h @@ -19,7 +19,11 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Error.h" +#include +#include #include +#include +#include namespace llvm::omp::target { namespace plugin { @@ -37,6 +41,12 @@ public: /// Initializes the handles to the number of devices we may need to service. RPCServerTy(plugin::GenericPluginTy &Plugin); + /// Deinitialize the associated memory and resources. + llvm::Error shutDown(); + + /// Initialize the worker thread. + llvm::Error startThread(); + /// Check if this device image is using an RPC server. This checks for the /// precense of an externally visible symbol in the device image that will /// be present whenever RPC code is called. 
@@ -51,17 +61,77 @@ public: plugin::GenericGlobalHandlerTy &Handler, plugin::DeviceImageTy &Image); - /// Runs the RPC server associated with the \p Device until the pending work - /// is cleared. - llvm::Error runServer(plugin::GenericDeviceTy &Device); - /// Deinitialize the RPC server for the given device. This will free the /// memory associated with the k llvm::Error deinitDevice(plugin::GenericDeviceTy &Device); private: /// Array from this device's identifier to its attached devices. - llvm::SmallVector Buffers; + std::unique_ptr Buffers; + + /// Array of associated devices. These must be alive as long as the server is. + std::unique_ptr Devices; + + /// A helper class for running the user thread that handles the RPC interface. + /// Because we only need to check the RPC server while any kernels are + /// working, we track submission / completion events to allow the thread to + /// sleep when it is not needed. + struct ServerThread { + std::thread Worker; + + /// A boolean indicating whether or not the worker thread should continue. + std::atomic Running; + + /// The number of currently executing kernels across all devices that need + /// the server thread to be running. + std::atomic NumUsers; + + /// The condition variable used to suspend the thread if no work is needed. + std::condition_variable CV; + std::mutex Mutex; + + /// A reference to all the RPC interfaces that the server is handling. + llvm::ArrayRef Buffers; + + /// A reference to the associated generic device for the buffer. + llvm::ArrayRef Devices; + + /// Initialize the worker thread to run in the background. + ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[], + size_t Length) + : Running(true), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length), + Devices(Devices, Length) {} + + ~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); } + + /// Notify the worker thread that there is a user that needs it. 
+ void notify() { + std::lock_guard Lock(Mutex); + NumUsers.fetch_add(1, std::memory_order_relaxed); + CV.notify_all(); + } + + /// Indicate that one of the dependent users has finished. + void finish() { + [[maybe_unused]] uint32_t Old = + NumUsers.fetch_sub(1, std::memory_order_relaxed); + assert(Old > 0 && "Attempt to signal finish with no pending work"); + } + + /// Destroy the worker thread and wait. + void shutDown(); + + /// Initialize the worker thread. + void startThread(); + + /// Run the server thread to continuously check the RPC interface for work + /// to be done for the device. + void run(); + }; + +public: + /// Pointer to the server thread instance. + std::unique_ptr Thread; }; } // namespace llvm::omp::target diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index a164bfb51d02..c9acabea6977 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -1057,6 +1057,9 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin, if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image)) return Err; + if (auto Err = Server.startThread()) + return Err; + RPCServer = &Server; DP("Running an RPC server on device %d\n", getDeviceId()); return Plugin::success(); @@ -1630,8 +1633,11 @@ Error GenericPluginTy::deinit() { if (GlobalHandler) delete GlobalHandler; - if (RPCServer) + if (RPCServer) { + if (Error Err = RPCServer->shutDown()) + return Err; delete RPCServer; + } if (RecordReplay) delete RecordReplay; diff --git a/offload/plugins-nextgen/common/src/RPC.cpp b/offload/plugins-nextgen/common/src/RPC.cpp index f20c8f7bcc5c..81ad9ca66808 100644 --- a/offload/plugins-nextgen/common/src/RPC.cpp +++ b/offload/plugins-nextgen/common/src/RPC.cpp @@ -21,8 +21,8 @@ using namespace omp; using namespace target; template -rpc::Status handle_offload_opcodes(plugin::GenericDeviceTy &Device, - 
rpc::Server::Port &Port) { +rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device, + rpc::Server::Port &Port) { switch (Port.get_opcode()) { case LIBC_MALLOC: { @@ -62,21 +62,99 @@ rpc::Status handle_offload_opcodes(plugin::GenericDeviceTy &Device, return rpc::RPC_SUCCESS; } -static rpc::Status handle_offload_opcodes(plugin::GenericDeviceTy &Device, - rpc::Server::Port &Port, - uint32_t NumLanes) { +static rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device, + rpc::Server::Port &Port, + uint32_t NumLanes) { if (NumLanes == 1) - return handle_offload_opcodes<1>(Device, Port); + return handleOffloadOpcodes<1>(Device, Port); else if (NumLanes == 32) - return handle_offload_opcodes<32>(Device, Port); + return handleOffloadOpcodes<32>(Device, Port); else if (NumLanes == 64) - return handle_offload_opcodes<64>(Device, Port); + return handleOffloadOpcodes<64>(Device, Port); else return rpc::RPC_ERROR; } +static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) { + uint64_t NumPorts = + std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT); + rpc::Server Server(NumPorts, Buffer); + + auto Port = Server.try_open(Device.getWarpSize()); + if (!Port) + return rpc::RPC_SUCCESS; + + rpc::Status Status = + handleOffloadOpcodes(Device, *Port, Device.getWarpSize()); + + // Let the `libc` library handle any other unhandled opcodes. 
+#ifdef LIBOMPTARGET_RPC_SUPPORT + if (Status == rpc::RPC_UNHANDLED_OPCODE) + Status = handle_libc_opcodes(*Port, Device.getWarpSize()); +#endif + + Port->close(); + + return Status; +} + +void RPCServerTy::ServerThread::startThread() { + Worker = std::thread([this]() { run(); }); +} + +void RPCServerTy::ServerThread::shutDown() { + { + std::lock_guard Lock(Mutex); + Running.store(false, std::memory_order_release); + CV.notify_all(); + } + if (Worker.joinable()) + Worker.join(); +} + +void RPCServerTy::ServerThread::run() { + std::unique_lock Lock(Mutex); + for (;;) { + CV.wait(Lock, [&]() { + return NumUsers.load(std::memory_order_acquire) > 0 || + !Running.load(std::memory_order_acquire); + }); + + if (!Running.load(std::memory_order_acquire)) + return; + + Lock.unlock(); + while (NumUsers.load(std::memory_order_relaxed) > 0 && + Running.load(std::memory_order_relaxed)) { + for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) { + if (!Buffer || !Device) + continue; + + // If running the server failed, print a message but keep running. 
+ if (runServer(*Device, Buffer) != rpc::RPC_SUCCESS) + FAILURE_MESSAGE("Unhandled or invalid RPC opcode!"); + } + } + Lock.lock(); + } +} + RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin) - : Buffers(Plugin.getNumDevices()) {} + : Buffers(std::make_unique(Plugin.getNumDevices())), + Devices(std::make_unique( + Plugin.getNumDevices())), + Thread(new ServerThread(Buffers.get(), Devices.get(), + Plugin.getNumDevices())) {} + +llvm::Error RPCServerTy::startThread() { + Thread->startThread(); + return Error::success(); +} + +llvm::Error RPCServerTy::shutDown() { + Thread->shutDown(); + return Error::success(); +} llvm::Expected RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device, @@ -108,35 +186,14 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, sizeof(rpc::Client), nullptr)) return Err; Buffers[Device.getDeviceId()] = RPCBuffer; - - return Error::success(); -} - -Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) { - uint64_t NumPorts = - std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT); - rpc::Server Server(NumPorts, Buffers[Device.getDeviceId()]); - - auto Port = Server.try_open(Device.getWarpSize()); - if (!Port) - return Error::success(); - - int Status = handle_offload_opcodes(Device, *Port, Device.getWarpSize()); - - // Let the `libc` library handle any other unhandled opcodes. 
-#ifdef LIBOMPTARGET_RPC_SUPPORT - if (Status == rpc::RPC_UNHANDLED_OPCODE) - Status = handle_libc_opcodes(*Port, Device.getWarpSize()); -#endif - - Port->close(); - if (Status != rpc::RPC_SUCCESS) - return createStringError("RPC server given invalid opcode!"); + Devices[Device.getDeviceId()] = &Device; return Error::success(); } Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) { Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST); + Buffers[Device.getDeviceId()] = nullptr; + Devices[Device.getDeviceId()] = nullptr; return Error::success(); } diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp index 5ec3adb9e4e3..7878499dbfcb 100644 --- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp +++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp @@ -63,6 +63,7 @@ DLWRAP(cuStreamCreate, 2) DLWRAP(cuStreamDestroy, 1) DLWRAP(cuStreamSynchronize, 1) DLWRAP(cuStreamQuery, 1) +DLWRAP(cuStreamAddCallback, 4) DLWRAP(cuCtxSetCurrent, 1) DLWRAP(cuDevicePrimaryCtxRelease, 1) DLWRAP(cuDevicePrimaryCtxGetState, 3) diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h index 16c8f7ad46c4..ad874735a25e 100644 --- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h +++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h @@ -286,6 +286,8 @@ static inline void *CU_LAUNCH_PARAM_END = (void *)0x00; static inline void *CU_LAUNCH_PARAM_BUFFER_POINTER = (void *)0x01; static inline void *CU_LAUNCH_PARAM_BUFFER_SIZE = (void *)0x02; +typedef void (*CUstreamCallback)(CUstream, CUresult, void *); + CUresult cuCtxGetDevice(CUdevice *); CUresult cuDeviceGet(CUdevice *, int); CUresult cuDeviceGetAttribute(int *, CUdevice_attribute, CUdevice); @@ -326,6 +328,7 @@ CUresult cuStreamCreate(CUstream *, unsigned); CUresult cuStreamDestroy(CUstream); CUresult cuStreamSynchronize(CUstream); CUresult cuStreamQuery(CUstream); +CUresult 
cuStreamAddCallback(CUstream, CUstreamCallback, void *, unsigned int); CUresult cuCtxSetCurrent(CUcontext); CUresult cuDevicePrimaryCtxRelease(CUdevice); CUresult cuDevicePrimaryCtxGetState(CUdevice, unsigned *, int *); diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp index 894d1c2214b9..52e8a100dc87 100644 --- a/offload/plugins-nextgen/cuda/src/rtl.cpp +++ b/offload/plugins-nextgen/cuda/src/rtl.cpp @@ -628,17 +628,7 @@ struct CUDADeviceTy : public GenericDeviceTy { Error synchronizeImpl(__tgt_async_info &AsyncInfo) override { CUstream Stream = reinterpret_cast(AsyncInfo.Queue); CUresult Res; - // If we have an RPC server running on this device we will continuously - // query it for work rather than blocking. - if (!getRPCServer()) { - Res = cuStreamSynchronize(Stream); - } else { - do { - Res = cuStreamQuery(Stream); - if (auto Err = getRPCServer()->runServer(*this)) - return Err; - } while (Res == CUDA_ERROR_NOT_READY); - } + Res = cuStreamSynchronize(Stream); // Once the stream is synchronized, return it to stream pool and reset // AsyncInfo. This is to make sure the synchronization only works for its @@ -823,17 +813,6 @@ struct CUDADeviceTy : public GenericDeviceTy { if (auto Err = getStream(AsyncInfoWrapper, Stream)) return Err; - // If there is already pending work on the stream it could be waiting for - // someone to check the RPC server. 
-    if (auto *RPCServer = getRPCServer()) {
-      CUresult Res = cuStreamQuery(Stream);
-      while (Res == CUDA_ERROR_NOT_READY) {
-        if (auto Err = RPCServer->runServer(*this))
-          return Err;
-        Res = cuStreamQuery(Stream);
-      }
-    }
-
     CUresult Res = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream);
     return Plugin::check(Res, "Error in cuMemcpyDtoHAsync: %s");
   }
@@ -1292,9 +1271,32 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
                     reinterpret_cast(&LaunchParams.Size),
                     CU_LAUNCH_PARAM_END};
 
+  // If we are running an RPC server we want to wake up the server thread
+  // whenever there is a kernel running and let it sleep otherwise.
+  if (GenericDevice.getRPCServer())
+    GenericDevice.Plugin.getRPCServer().Thread->notify();
+
   CUresult Res = cuLaunchKernel(Func, NumBlocks[0], NumBlocks[1], NumBlocks[2],
                                 NumThreads[0], NumThreads[1], NumThreads[2],
                                 MaxDynCGroupMem, Stream, nullptr, Config);
+
+  // Register a callback to indicate when the kernel is complete. This uses
+  // cuStreamAddCallback, the entry point wrapped by the dynamic CUDA loader.
+  if (GenericDevice.getRPCServer()) {
+    CUresult CallbackRes = cuStreamAddCallback(
+        Stream,
+        [](CUstream, CUresult, void *Data) {
+          GenericPluginTy &Plugin = *reinterpret_cast<GenericPluginTy *>(Data);
+          Plugin.getRPCServer().Thread->finish();
+        },
+        &GenericDevice.Plugin, /*Flags=*/0);
+
+    // Do not silently drop a failed registration: finish() would never run.
+    // The server thread stays awake in that case; a launch error still wins.
+    if (Res == CUDA_SUCCESS && CallbackRes != CUDA_SUCCESS)
+      return Plugin::check(CallbackRes, "Error in cuStreamAddCallback: %s");
+  }
+
   return Plugin::check(Res, "Error in cuLaunchKernel for '%s': %s", getName());
 }
 
diff --git a/offload/test/libc/server.c b/offload/test/libc/server.c
new file mode 100644
index 000000000000..67f60a648235
--- /dev/null
+++ b/offload/test/libc/server.c
@@ -0,0 +1,56 @@
+// RUN: %libomptarget-compile-run-and-check-generic
+
+// REQUIRES: libc
+
+#include
+#include
+#include
+
+#pragma omp begin declare variant match(device = {kind(gpu)})
+// Extension provided by the 'libc' project.
+unsigned long long __llvm_omp_host_call(void *fn, void *args, size_t size);
+#pragma omp declare target to(__llvm_omp_host_call) device_type(nohost)
+#pragma omp end declare variant
+
+#pragma omp begin declare variant match(device = {kind(cpu)})
+// Dummy host implementation to make this work for all targets.
+unsigned long long __llvm_omp_host_call(void *fn, void *args, size_t size) { + return ((unsigned long long (*)(void *))fn)(args); +} +#pragma omp end declare variant + +long long foo(void *data) { return -1; } + +void *fn_ptr = NULL; +#pragma omp declare target to(fn_ptr) + +int main() { + fn_ptr = (void *)&foo; +#pragma omp target update to(fn_ptr) + + for (int i = 0; i < 4; ++i) { +#pragma omp target + { + long long res = __llvm_omp_host_call(fn_ptr, NULL, 0); + assert(res == -1 && "RPC call failed\n"); + } + + for (int j = 0; j < 128; ++j) { +#pragma omp target nowait + { + long long res = __llvm_omp_host_call(fn_ptr, NULL, 0); + assert(res == -1 && "RPC call failed\n"); + } + } +#pragma omp taskwait + +#pragma omp target + { + long long res = __llvm_omp_host_call(fn_ptr, NULL, 0); + assert(res == -1 && "RPC call failed\n"); + } + } + + // CHECK: PASS + puts("PASS"); +}