[Offload] Make olLaunchKernel test thread safe (#149497)
This sprinkles a few mutexes around the plugin interface so that the olLaunchKernel CTS test now passes when run on multiple threads. Part of this also involved changing the device synchronise interface so that it can optionally skip freeing the underlying queue (the unconditional free was the source of a race condition in liboffload).
Commit 910d7e90bf (parent 24ea1559d3)
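
For orientation before the hunks, here is a minimal sketch of the changed synchronisation contract, assembled from the diff below (the summary comments are editorial, not part of the commit):

    // GenericDeviceTy (plugin interface): synchronize now takes a ReleaseQueue
    // flag. The default (true) preserves the old behaviour of returning the
    // underlying queue/stream to the resource pool after waiting; false leaves
    // the queue alive so other threads can keep submitting work to it.
    Error synchronize(__tgt_async_info *AsyncInfo, bool ReleaseQueue = true);
    virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo,
                                  bool ReleaseQueue) = 0;

    // liboffload's olSyncQueue_impl opts out of releasing the queue:
    if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo,
                                                      /*ReleaseQueue=*/false))
      return Err;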
@@ -21,6 +21,7 @@
 
 #include <cstddef>
 #include <cstdint>
+#include <mutex>
 
 extern "C" {
 
@@ -76,6 +77,9 @@ struct __tgt_async_info {
   /// should be freed after finalization.
   llvm::SmallVector<void *, 2> AssociatedAllocations;
 
+  /// Mutex to guard access to AssociatedAllocations and the Queue.
+  std::mutex Mutex;
+
   /// The kernel launch environment used to issue a kernel. Stored here to
   /// ensure it is a valid location while the transfer to the device is
   /// happening.
@@ -208,7 +208,7 @@ Error initPlugins(OffloadContext &Context) {
 }
 
 Error olInit_impl() {
-  std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
+  std::lock_guard<std::mutex> Lock(OffloadContextValMutex);
 
   if (isOffloadInitialized()) {
     OffloadContext::get().RefCount++;
@@ -226,7 +226,7 @@ Error olInit_impl() {
 }
 
 Error olShutDown_impl() {
-  std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
+  std::lock_guard<std::mutex> Lock(OffloadContextValMutex);
 
   if (--OffloadContext::get().RefCount != 0)
     return Error::success();
@@ -487,16 +487,13 @@ Error olSyncQueue_impl(ol_queue_handle_t Queue) {
   // Host plugin doesn't have a queue set so it's not safe to call synchronize
   // on it, but we have nothing to synchronize in that situation anyway.
   if (Queue->AsyncInfo->Queue) {
-    if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo))
+    // We don't need to release the queue and we would like the ability for
+    // other offload threads to submit work concurrently, so pass "false" here
+    // so we don't release the underlying queue object.
+    if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo, false))
       return Err;
   }
 
-  // Recreate the stream resource so the queue can be reused
-  // TODO: Would be easier for the synchronization to (optionally) not release
-  // it to begin with.
-  if (auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo))
-    return Res;
-
   return Error::success();
 }
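
Because the queue is no longer released during synchronisation, the old fallback that recreated the stream afterwards via initAsyncInfo (together with its TODO, which asked for exactly this optional-release behaviour) becomes dead code and is dropped; the same underlying queue object stays live across olSyncQueue calls.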
@@ -747,7 +744,7 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
                        ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
   auto &Device = Program->Image->getDevice();
 
-  std::lock_guard<std::mutex> Lock{Program->SymbolListMutex};
+  std::lock_guard<std::mutex> Lock(Program->SymbolListMutex);
 
   switch (Kind) {
   case OL_SYMBOL_KIND_KERNEL: {
@@ -116,7 +116,7 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
            TargetDataFuncPtrTy TargetDataFunction, const char *RegionTypeMsg,
            const char *RegionName) {
   assert(PM && "Runtime not initialized");
-  static_assert(std::is_convertible_v<TargetAsyncInfoTy, AsyncInfoTy>,
+  static_assert(std::is_convertible_v<TargetAsyncInfoTy &, AsyncInfoTy &>,
                 "TargetAsyncInfoTy must be convertible to AsyncInfoTy.");
 
   TIMESCOPE_WITH_DETAILS_AND_IDENT("Runtime: Data Copy",
@@ -311,7 +311,7 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
                                int32_t ThreadLimit, void *HostPtr,
                                KernelArgsTy *KernelArgs) {
   assert(PM && "Runtime not initialized");
-  static_assert(std::is_convertible_v<TargetAsyncInfoTy, AsyncInfoTy>,
+  static_assert(std::is_convertible_v<TargetAsyncInfoTy &, AsyncInfoTy &>,
                 "Target AsyncInfoTy must be convertible to AsyncInfoTy.");
   DP("Entering target region for device %" PRId64 " with entry point " DPxMOD
      "\n",
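
The static_assert change in the two hunks above is subtle but necessary: std::is_convertible_v<TargetAsyncInfoTy, AsyncInfoTy> asks whether a value of the derived type can be converted to an AsyncInfoTy value, which requires AsyncInfoTy to be copy- or move-constructible. Presumably the std::mutex added to __tgt_async_info in this commit makes that struct (and thus AsyncInfoTy, which embeds one) non-copyable, so the value-based check would now fail. Testing TargetAsyncInfoTy & against AsyncInfoTy & instead only requires that the reference binds, which is what the callers actually rely on.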
@@ -2232,16 +2232,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   /// Get the stream of the asynchronous info structure or get a new one.
   Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper,
                   AMDGPUStreamTy *&Stream) {
-    // Get the stream (if any) from the async info.
-    Stream = AsyncInfoWrapper.getQueueAs<AMDGPUStreamTy *>();
-    if (!Stream) {
-      // There was no stream; get an idle one.
-      if (auto Err = AMDGPUStreamManager.getResource(Stream))
-        return Err;
-
-      // Modify the async info's stream.
-      AsyncInfoWrapper.setQueueAs<AMDGPUStreamTy *>(Stream);
-    }
+    auto WrapperStream =
+        AsyncInfoWrapper.getOrInitQueue<AMDGPUStreamTy *>(AMDGPUStreamManager);
+    if (!WrapperStream)
+      return WrapperStream.takeError();
+    Stream = *WrapperStream;
     return Plugin::success();
   }
 
@@ -2296,7 +2291,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   }
 
   /// Synchronize current thread with the pending operations on the async info.
-  Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+  Error synchronizeImpl(__tgt_async_info &AsyncInfo,
+                        bool ReleaseQueue) override {
     AMDGPUStreamTy *Stream =
         reinterpret_cast<AMDGPUStreamTy *>(AsyncInfo.Queue);
     assert(Stream && "Invalid stream");
@@ -2307,8 +2303,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     // Once the stream is synchronized, return it to stream pool and reset
     // AsyncInfo. This is to make sure the synchronization only works for its
    // own tasks.
-    AsyncInfo.Queue = nullptr;
-    return AMDGPUStreamManager.returnResource(Stream);
+    if (ReleaseQueue) {
+      AsyncInfo.Queue = nullptr;
+      return AMDGPUStreamManager.returnResource(Stream);
+    }
+    return Plugin::success();
   }
 
   /// Query for the completion of the pending operations on the async info.
@@ -60,6 +60,7 @@ struct GenericPluginTy;
 struct GenericKernelTy;
 struct GenericDeviceTy;
 struct RecordReplayTy;
+template <typename ResourceRef> class GenericDeviceResourceManagerTy;
 
 namespace Plugin {
 /// Create a success error. This is the same as calling Error::success(), but
@@ -127,6 +128,20 @@ struct AsyncInfoWrapperTy {
     AsyncInfoPtr->Queue = Queue;
   }
 
+  /// Get the queue, using the provided resource manager to initialise it if it
+  /// doesn't exist.
+  template <typename Ty, typename RMTy>
+  Expected<Ty>
+  getOrInitQueue(GenericDeviceResourceManagerTy<RMTy> &ResourceManager) {
+    std::lock_guard<std::mutex> Lock(AsyncInfoPtr->Mutex);
+    if (!AsyncInfoPtr->Queue) {
+      if (auto Err = ResourceManager.getResource(
+              *reinterpret_cast<Ty *>(&AsyncInfoPtr->Queue)))
+        return Err;
+    }
+    return getQueueAs<Ty>();
+  }
+
   /// Synchronize with the __tgt_async_info's pending operations if it's the
   /// internal async info. The error associated to the asynchronous operations
   /// issued in this queue must be provided in \p Err. This function will update
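
getOrInitQueue folds the previous two-step pattern in the plugins (read the queue, then create one if it was null) into a single check-and-create performed under the new __tgt_async_info mutex, so two threads racing through getStream can no longer both observe a null queue and both pull a resource from the manager.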
@@ -138,6 +153,7 @@ struct AsyncInfoWrapperTy {
   /// Register \p Ptr as an associated allocation that is freed after
   /// finalization.
   void freeAllocationAfterSynchronization(void *Ptr) {
+    std::lock_guard<std::mutex> AllocationGuard(AsyncInfoPtr->Mutex);
     AsyncInfoPtr->AssociatedAllocations.push_back(Ptr);
   }
 
@@ -827,9 +843,12 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
   Error setupRPCServer(GenericPluginTy &Plugin, DeviceImageTy &Image);
 
   /// Synchronize the current thread with the pending operations on the
-  /// __tgt_async_info structure.
-  Error synchronize(__tgt_async_info *AsyncInfo);
-  virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo) = 0;
+  /// __tgt_async_info structure. If ReleaseQueue is false, then the
+  /// underlying queue will not be released. In this case, additional
+  /// work may be submitted to the queue whilst a synchronize is running.
+  Error synchronize(__tgt_async_info *AsyncInfo, bool ReleaseQueue = true);
+  virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo,
+                                bool ReleaseQueue) = 0;
 
   /// Invokes any global constructors on the device if present and is required
   /// by the target.
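
ReleaseQueue defaults to true, so every existing synchronize caller keeps the old release-on-sync behaviour without modification; only olSyncQueue_impl above passes false.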
@@ -1335,18 +1335,25 @@ Error PinnedAllocationMapTy::unlockUnmappedHostBuffer(void *HstPtr) {
   return eraseEntry(*Entry);
 }
 
-Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo) {
-  if (!AsyncInfo || !AsyncInfo->Queue)
-    return Plugin::error(ErrorCode::INVALID_ARGUMENT,
-                         "invalid async info queue");
-
-  if (auto Err = synchronizeImpl(*AsyncInfo))
-    return Err;
-
-  for (auto *Ptr : AsyncInfo->AssociatedAllocations)
+Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo,
+                                   bool ReleaseQueue) {
+  SmallVector<void *> AllocsToDelete{};
+  {
+    std::lock_guard<std::mutex> AllocationGuard{AsyncInfo->Mutex};
+
+    if (!AsyncInfo || !AsyncInfo->Queue)
+      return Plugin::error(ErrorCode::INVALID_ARGUMENT,
+                           "invalid async info queue");
+
+    if (auto Err = synchronizeImpl(*AsyncInfo, ReleaseQueue))
+      return Err;
+
+    std::swap(AllocsToDelete, AsyncInfo->AssociatedAllocations);
+  }
+
+  for (auto *Ptr : AllocsToDelete)
     if (auto Err = dataDelete(Ptr, TargetAllocTy::TARGET_ALLOC_DEVICE))
       return Err;
-  AsyncInfo->AssociatedAllocations.clear();
 
   return Plugin::success();
 }
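
Note the shape of the rewritten GenericDeviceTy::synchronize: the AssociatedAllocations vector is swapped into a local while the mutex is held, and the dataDelete calls run only after the lock is dropped, so a slow device free cannot block other threads registering allocations against the same async info.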
@@ -522,16 +522,11 @@ struct CUDADeviceTy : public GenericDeviceTy {
 
   /// Get the stream of the asynchronous info structure or get a new one.
   Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper, CUstream &Stream) {
-    // Get the stream (if any) from the async info.
-    Stream = AsyncInfoWrapper.getQueueAs<CUstream>();
-    if (!Stream) {
-      // There was no stream; get an idle one.
-      if (auto Err = CUDAStreamManager.getResource(Stream))
-        return Err;
-
-      // Modify the async info's stream.
-      AsyncInfoWrapper.setQueueAs<CUstream>(Stream);
-    }
+    auto WrapperStream =
+        AsyncInfoWrapper.getOrInitQueue<CUstream>(CUDAStreamManager);
+    if (!WrapperStream)
+      return WrapperStream.takeError();
+    Stream = *WrapperStream;
     return Plugin::success();
   }
 
@@ -642,17 +637,20 @@ struct CUDADeviceTy : public GenericDeviceTy {
   }
 
   /// Synchronize current thread with the pending operations on the async info.
-  Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+  Error synchronizeImpl(__tgt_async_info &AsyncInfo,
+                        bool ReleaseQueue) override {
     CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
     CUresult Res;
     Res = cuStreamSynchronize(Stream);
 
-    // Once the stream is synchronized, return it to stream pool and reset
-    // AsyncInfo. This is to make sure the synchronization only works for its
-    // own tasks.
-    AsyncInfo.Queue = nullptr;
-    if (auto Err = CUDAStreamManager.returnResource(Stream))
-      return Err;
+    // Once the stream is synchronized and we want to release the queue, return
+    // it to stream pool and reset AsyncInfo. This is to make sure the
+    // synchronization only works for its own tasks.
+    if (ReleaseQueue) {
+      AsyncInfo.Queue = nullptr;
+      if (auto Err = CUDAStreamManager.returnResource(Stream))
+        return Err;
+    }
 
     return Plugin::check(Res, "error in cuStreamSynchronize: %s");
   }
@@ -297,7 +297,8 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
 
   /// All functions are already synchronous. No need to do anything on this
   /// synchronization function.
-  Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+  Error synchronizeImpl(__tgt_async_info &AsyncInfo,
+                        bool ReleaseQueue) override {
     return Plugin::success();
   }
 
@@ -9,6 +9,7 @@
 #include <OffloadAPI.h>
 #include <OffloadPrint.hpp>
 #include <gtest/gtest.h>
+#include <thread>
 
 #include "Environment.hpp"
 
@@ -57,6 +58,23 @@ inline std::string SanitizeString(const std::string &Str) {
   return NewStr;
 }
 
+template <typename Fn> inline void threadify(Fn body) {
+  std::vector<std::thread> Threads;
+  for (size_t I = 0; I < 20; I++) {
+    Threads.emplace_back(
+        [&body](size_t I) {
+          std::string ScopeMsg{"Thread #"};
+          ScopeMsg.append(std::to_string(I));
+          SCOPED_TRACE(ScopeMsg);
+          body(I);
+        },
+        I);
+  }
+  for (auto &T : Threads) {
+    T.join();
+  }
+}
+
 struct OffloadTest : ::testing::Test {
   ol_device_handle_t Host = TestEnvironment::getHostDevice();
 };
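
The SuccessThreaded test in the next hunk uses this helper to run the full allocate/launch/sync/verify cycle on 20 threads at once; each thread gets a SCOPED_TRACE so a gtest failure reports which thread index tripped it.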
@@ -104,6 +104,29 @@ TEST_P(olLaunchKernelFooTest, Success) {
   ASSERT_SUCCESS(olMemFree(Mem));
 }
 
+TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
+  threadify([&](size_t) {
+    void *Mem;
+    ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                              LaunchArgs.GroupSize.x * sizeof(uint32_t), &Mem));
+    struct {
+      void *Mem;
+    } Args{Mem};
+
+    ASSERT_SUCCESS(olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args),
+                                  &LaunchArgs));
+
+    ASSERT_SUCCESS(olSyncQueue(Queue));
+
+    uint32_t *Data = (uint32_t *)Mem;
+    for (uint32_t i = 0; i < 64; i++) {
+      ASSERT_EQ(Data[i], i);
+    }
+
+    ASSERT_SUCCESS(olMemFree(Mem));
+  });
+}
+
 TEST_P(olLaunchKernelNoArgsTest, Success) {
   ASSERT_SUCCESS(
       olLaunchKernel(Queue, Device, Kernel, nullptr, 0, &LaunchArgs));