[Offload] Make olLaunchKernel test thread safe (#149497)
This sprinkles a few mutexes around the plugin interface so that the olLaunchKernel CTS test now passes when ran on multiple threads. Part of this also involved changing the interface for device synchronise so that it can optionally not free the underlying queue (which introduced a race condition in liboffload).
This commit is contained in:
parent
24ea1559d3
commit
910d7e90bf
@ -21,6 +21,7 @@
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <mutex>
|
||||
|
||||
extern "C" {
|
||||
|
||||
@ -76,6 +77,9 @@ struct __tgt_async_info {
|
||||
/// should be freed after finalization.
|
||||
llvm::SmallVector<void *, 2> AssociatedAllocations;
|
||||
|
||||
/// Mutex to guard access to AssociatedAllocations and the Queue.
|
||||
std::mutex Mutex;
|
||||
|
||||
/// The kernel launch environment used to issue a kernel. Stored here to
|
||||
/// ensure it is a valid location while the transfer to the device is
|
||||
/// happening.
|
||||
|
@ -208,7 +208,7 @@ Error initPlugins(OffloadContext &Context) {
|
||||
}
|
||||
|
||||
Error olInit_impl() {
|
||||
std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
|
||||
std::lock_guard<std::mutex> Lock(OffloadContextValMutex);
|
||||
|
||||
if (isOffloadInitialized()) {
|
||||
OffloadContext::get().RefCount++;
|
||||
@ -226,7 +226,7 @@ Error olInit_impl() {
|
||||
}
|
||||
|
||||
Error olShutDown_impl() {
|
||||
std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
|
||||
std::lock_guard<std::mutex> Lock(OffloadContextValMutex);
|
||||
|
||||
if (--OffloadContext::get().RefCount != 0)
|
||||
return Error::success();
|
||||
@ -487,16 +487,13 @@ Error olSyncQueue_impl(ol_queue_handle_t Queue) {
|
||||
// Host plugin doesn't have a queue set so it's not safe to call synchronize
|
||||
// on it, but we have nothing to synchronize in that situation anyway.
|
||||
if (Queue->AsyncInfo->Queue) {
|
||||
if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo))
|
||||
// We don't need to release the queue and we would like the ability for
|
||||
// other offload threads to submit work concurrently, so pass "false" here
|
||||
// so we don't release the underlying queue object.
|
||||
if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo, false))
|
||||
return Err;
|
||||
}
|
||||
|
||||
// Recreate the stream resource so the queue can be reused
|
||||
// TODO: Would be easier for the synchronization to (optionally) not release
|
||||
// it to begin with.
|
||||
if (auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo))
|
||||
return Res;
|
||||
|
||||
return Error::success();
|
||||
}
|
||||
|
||||
@ -747,7 +744,7 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
|
||||
ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
|
||||
auto &Device = Program->Image->getDevice();
|
||||
|
||||
std::lock_guard<std::mutex> Lock{Program->SymbolListMutex};
|
||||
std::lock_guard<std::mutex> Lock(Program->SymbolListMutex);
|
||||
|
||||
switch (Kind) {
|
||||
case OL_SYMBOL_KIND_KERNEL: {
|
||||
|
@ -116,7 +116,7 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
|
||||
TargetDataFuncPtrTy TargetDataFunction, const char *RegionTypeMsg,
|
||||
const char *RegionName) {
|
||||
assert(PM && "Runtime not initialized");
|
||||
static_assert(std::is_convertible_v<TargetAsyncInfoTy, AsyncInfoTy>,
|
||||
static_assert(std::is_convertible_v<TargetAsyncInfoTy &, AsyncInfoTy &>,
|
||||
"TargetAsyncInfoTy must be convertible to AsyncInfoTy.");
|
||||
|
||||
TIMESCOPE_WITH_DETAILS_AND_IDENT("Runtime: Data Copy",
|
||||
@ -311,7 +311,7 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
|
||||
int32_t ThreadLimit, void *HostPtr,
|
||||
KernelArgsTy *KernelArgs) {
|
||||
assert(PM && "Runtime not initialized");
|
||||
static_assert(std::is_convertible_v<TargetAsyncInfoTy, AsyncInfoTy>,
|
||||
static_assert(std::is_convertible_v<TargetAsyncInfoTy &, AsyncInfoTy &>,
|
||||
"Target AsyncInfoTy must be convertible to AsyncInfoTy.");
|
||||
DP("Entering target region for device %" PRId64 " with entry point " DPxMOD
|
||||
"\n",
|
||||
|
@ -2232,16 +2232,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
|
||||
/// Get the stream of the asynchronous info structure or get a new one.
|
||||
Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper,
|
||||
AMDGPUStreamTy *&Stream) {
|
||||
// Get the stream (if any) from the async info.
|
||||
Stream = AsyncInfoWrapper.getQueueAs<AMDGPUStreamTy *>();
|
||||
if (!Stream) {
|
||||
// There was no stream; get an idle one.
|
||||
if (auto Err = AMDGPUStreamManager.getResource(Stream))
|
||||
return Err;
|
||||
|
||||
// Modify the async info's stream.
|
||||
AsyncInfoWrapper.setQueueAs<AMDGPUStreamTy *>(Stream);
|
||||
}
|
||||
auto WrapperStream =
|
||||
AsyncInfoWrapper.getOrInitQueue<AMDGPUStreamTy *>(AMDGPUStreamManager);
|
||||
if (!WrapperStream)
|
||||
return WrapperStream.takeError();
|
||||
Stream = *WrapperStream;
|
||||
return Plugin::success();
|
||||
}
|
||||
|
||||
@ -2296,7 +2291,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
|
||||
}
|
||||
|
||||
/// Synchronize current thread with the pending operations on the async info.
|
||||
Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
|
||||
Error synchronizeImpl(__tgt_async_info &AsyncInfo,
|
||||
bool ReleaseQueue) override {
|
||||
AMDGPUStreamTy *Stream =
|
||||
reinterpret_cast<AMDGPUStreamTy *>(AsyncInfo.Queue);
|
||||
assert(Stream && "Invalid stream");
|
||||
@ -2307,8 +2303,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
|
||||
// Once the stream is synchronized, return it to stream pool and reset
|
||||
// AsyncInfo. This is to make sure the synchronization only works for its
|
||||
// own tasks.
|
||||
AsyncInfo.Queue = nullptr;
|
||||
return AMDGPUStreamManager.returnResource(Stream);
|
||||
if (ReleaseQueue) {
|
||||
AsyncInfo.Queue = nullptr;
|
||||
return AMDGPUStreamManager.returnResource(Stream);
|
||||
}
|
||||
return Plugin::success();
|
||||
}
|
||||
|
||||
/// Query for the completion of the pending operations on the async info.
|
||||
|
@ -60,6 +60,7 @@ struct GenericPluginTy;
|
||||
struct GenericKernelTy;
|
||||
struct GenericDeviceTy;
|
||||
struct RecordReplayTy;
|
||||
template <typename ResourceRef> class GenericDeviceResourceManagerTy;
|
||||
|
||||
namespace Plugin {
|
||||
/// Create a success error. This is the same as calling Error::success(), but
|
||||
@ -127,6 +128,20 @@ struct AsyncInfoWrapperTy {
|
||||
AsyncInfoPtr->Queue = Queue;
|
||||
}
|
||||
|
||||
/// Get the queue, using the provided resource manager to initialise it if it
|
||||
/// doesn't exist.
|
||||
template <typename Ty, typename RMTy>
|
||||
Expected<Ty>
|
||||
getOrInitQueue(GenericDeviceResourceManagerTy<RMTy> &ResourceManager) {
|
||||
std::lock_guard<std::mutex> Lock(AsyncInfoPtr->Mutex);
|
||||
if (!AsyncInfoPtr->Queue) {
|
||||
if (auto Err = ResourceManager.getResource(
|
||||
*reinterpret_cast<Ty *>(&AsyncInfoPtr->Queue)))
|
||||
return Err;
|
||||
}
|
||||
return getQueueAs<Ty>();
|
||||
}
|
||||
|
||||
/// Synchronize with the __tgt_async_info's pending operations if it's the
|
||||
/// internal async info. The error associated to the asynchronous operations
|
||||
/// issued in this queue must be provided in \p Err. This function will update
|
||||
@ -138,6 +153,7 @@ struct AsyncInfoWrapperTy {
|
||||
/// Register \p Ptr as an associated allocation that is freed after
|
||||
/// finalization.
|
||||
void freeAllocationAfterSynchronization(void *Ptr) {
|
||||
std::lock_guard<std::mutex> AllocationGuard(AsyncInfoPtr->Mutex);
|
||||
AsyncInfoPtr->AssociatedAllocations.push_back(Ptr);
|
||||
}
|
||||
|
||||
@ -827,9 +843,12 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
|
||||
Error setupRPCServer(GenericPluginTy &Plugin, DeviceImageTy &Image);
|
||||
|
||||
/// Synchronize the current thread with the pending operations on the
|
||||
/// __tgt_async_info structure.
|
||||
Error synchronize(__tgt_async_info *AsyncInfo);
|
||||
virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo) = 0;
|
||||
/// __tgt_async_info structure. If ReleaseQueue is false, then the
|
||||
// underlying queue will not be released. In this case, additional
|
||||
// work may be submitted to the queue whilst a synchronize is running.
|
||||
Error synchronize(__tgt_async_info *AsyncInfo, bool ReleaseQueue = true);
|
||||
virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo,
|
||||
bool ReleaseQueue) = 0;
|
||||
|
||||
/// Invokes any global constructors on the device if present and is required
|
||||
/// by the target.
|
||||
|
@ -1335,18 +1335,25 @@ Error PinnedAllocationMapTy::unlockUnmappedHostBuffer(void *HstPtr) {
|
||||
return eraseEntry(*Entry);
|
||||
}
|
||||
|
||||
Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo) {
|
||||
if (!AsyncInfo || !AsyncInfo->Queue)
|
||||
return Plugin::error(ErrorCode::INVALID_ARGUMENT,
|
||||
"invalid async info queue");
|
||||
Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo,
|
||||
bool ReleaseQueue) {
|
||||
SmallVector<void *> AllocsToDelete{};
|
||||
{
|
||||
std::lock_guard<std::mutex> AllocationGuard{AsyncInfo->Mutex};
|
||||
|
||||
if (auto Err = synchronizeImpl(*AsyncInfo))
|
||||
return Err;
|
||||
if (!AsyncInfo || !AsyncInfo->Queue)
|
||||
return Plugin::error(ErrorCode::INVALID_ARGUMENT,
|
||||
"invalid async info queue");
|
||||
|
||||
for (auto *Ptr : AsyncInfo->AssociatedAllocations)
|
||||
if (auto Err = synchronizeImpl(*AsyncInfo, ReleaseQueue))
|
||||
return Err;
|
||||
|
||||
std::swap(AllocsToDelete, AsyncInfo->AssociatedAllocations);
|
||||
}
|
||||
|
||||
for (auto *Ptr : AllocsToDelete)
|
||||
if (auto Err = dataDelete(Ptr, TargetAllocTy::TARGET_ALLOC_DEVICE))
|
||||
return Err;
|
||||
AsyncInfo->AssociatedAllocations.clear();
|
||||
|
||||
return Plugin::success();
|
||||
}
|
||||
|
@ -522,16 +522,11 @@ struct CUDADeviceTy : public GenericDeviceTy {
|
||||
|
||||
/// Get the stream of the asynchronous info structure or get a new one.
|
||||
Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper, CUstream &Stream) {
|
||||
// Get the stream (if any) from the async info.
|
||||
Stream = AsyncInfoWrapper.getQueueAs<CUstream>();
|
||||
if (!Stream) {
|
||||
// There was no stream; get an idle one.
|
||||
if (auto Err = CUDAStreamManager.getResource(Stream))
|
||||
return Err;
|
||||
|
||||
// Modify the async info's stream.
|
||||
AsyncInfoWrapper.setQueueAs<CUstream>(Stream);
|
||||
}
|
||||
auto WrapperStream =
|
||||
AsyncInfoWrapper.getOrInitQueue<CUstream>(CUDAStreamManager);
|
||||
if (!WrapperStream)
|
||||
return WrapperStream.takeError();
|
||||
Stream = *WrapperStream;
|
||||
return Plugin::success();
|
||||
}
|
||||
|
||||
@ -642,17 +637,20 @@ struct CUDADeviceTy : public GenericDeviceTy {
|
||||
}
|
||||
|
||||
/// Synchronize current thread with the pending operations on the async info.
|
||||
Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
|
||||
Error synchronizeImpl(__tgt_async_info &AsyncInfo,
|
||||
bool ReleaseQueue) override {
|
||||
CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
|
||||
CUresult Res;
|
||||
Res = cuStreamSynchronize(Stream);
|
||||
|
||||
// Once the stream is synchronized, return it to stream pool and reset
|
||||
// AsyncInfo. This is to make sure the synchronization only works for its
|
||||
// own tasks.
|
||||
AsyncInfo.Queue = nullptr;
|
||||
if (auto Err = CUDAStreamManager.returnResource(Stream))
|
||||
return Err;
|
||||
// Once the stream is synchronized and we want to release the queue, return
|
||||
// it to stream pool and reset AsyncInfo. This is to make sure the
|
||||
// synchronization only works for its own tasks.
|
||||
if (ReleaseQueue) {
|
||||
AsyncInfo.Queue = nullptr;
|
||||
if (auto Err = CUDAStreamManager.returnResource(Stream))
|
||||
return Err;
|
||||
}
|
||||
|
||||
return Plugin::check(Res, "error in cuStreamSynchronize: %s");
|
||||
}
|
||||
|
@ -297,7 +297,8 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
|
||||
|
||||
/// All functions are already synchronous. No need to do anything on this
|
||||
/// synchronization function.
|
||||
Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
|
||||
Error synchronizeImpl(__tgt_async_info &AsyncInfo,
|
||||
bool ReleaseQueue) override {
|
||||
return Plugin::success();
|
||||
}
|
||||
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include <OffloadAPI.h>
|
||||
#include <OffloadPrint.hpp>
|
||||
#include <gtest/gtest.h>
|
||||
#include <thread>
|
||||
|
||||
#include "Environment.hpp"
|
||||
|
||||
@ -57,6 +58,23 @@ inline std::string SanitizeString(const std::string &Str) {
|
||||
return NewStr;
|
||||
}
|
||||
|
||||
template <typename Fn> inline void threadify(Fn body) {
|
||||
std::vector<std::thread> Threads;
|
||||
for (size_t I = 0; I < 20; I++) {
|
||||
Threads.emplace_back(
|
||||
[&body](size_t I) {
|
||||
std::string ScopeMsg{"Thread #"};
|
||||
ScopeMsg.append(std::to_string(I));
|
||||
SCOPED_TRACE(ScopeMsg);
|
||||
body(I);
|
||||
},
|
||||
I);
|
||||
}
|
||||
for (auto &T : Threads) {
|
||||
T.join();
|
||||
}
|
||||
}
|
||||
|
||||
struct OffloadTest : ::testing::Test {
|
||||
ol_device_handle_t Host = TestEnvironment::getHostDevice();
|
||||
};
|
||||
|
@ -104,6 +104,29 @@ TEST_P(olLaunchKernelFooTest, Success) {
|
||||
ASSERT_SUCCESS(olMemFree(Mem));
|
||||
}
|
||||
|
||||
TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
|
||||
threadify([&](size_t) {
|
||||
void *Mem;
|
||||
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
|
||||
LaunchArgs.GroupSize.x * sizeof(uint32_t), &Mem));
|
||||
struct {
|
||||
void *Mem;
|
||||
} Args{Mem};
|
||||
|
||||
ASSERT_SUCCESS(olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args),
|
||||
&LaunchArgs));
|
||||
|
||||
ASSERT_SUCCESS(olSyncQueue(Queue));
|
||||
|
||||
uint32_t *Data = (uint32_t *)Mem;
|
||||
for (uint32_t i = 0; i < 64; i++) {
|
||||
ASSERT_EQ(Data[i], i);
|
||||
}
|
||||
|
||||
ASSERT_SUCCESS(olMemFree(Mem));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_P(olLaunchKernelNoArgsTest, Success) {
|
||||
ASSERT_SUCCESS(
|
||||
olLaunchKernel(Queue, Device, Kernel, nullptr, 0, &LaunchArgs));
|
||||
|
Loading…
x
Reference in New Issue
Block a user