1211 lines
43 KiB
C++
1211 lines
43 KiB
C++
//===- ol_impl.cpp - Implementation of the new LLVM/Offload API ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This contains the definitions of the new LLVM/Offload API entry points. See
|
|
// new-api/API/README.md for more information.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "OffloadImpl.hpp"
|
|
#include "Helpers.hpp"
|
|
#include "OffloadPrint.hpp"
|
|
#include "PluginManager.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include <OffloadAPI.h>
|
|
|
|
#include <cstdint>
|
|
#include <mutex>
|
|
|
|
// TODO: Some plugins expect to be linked into libomptarget, which defines these
// symbols in order to implement OMPT callbacks. The least invasive workaround
// is to define them here in libLLVMOffload as false/null so that they are never
// used. In the future it would be better to let the plugins implement callbacks
// without pulling in details from libomptarget.
#ifdef OMPT_SUPPORT
namespace llvm::omp::target::ompt {
bool Initialized{false};
ompt_get_callback_t lookupCallbackByCode{nullptr};
ompt_function_lookup_t lookupCallbackByName{nullptr};
} // namespace llvm::omp::target::ompt
#endif
|
|
|
|
using namespace llvm::omp::target;
|
|
using namespace llvm::omp::target::plugin;
|
|
using namespace error;
|
|
|
|
struct ol_platform_impl_t {
|
|
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
|
|
ol_platform_backend_t BackendType)
|
|
: BackendType(BackendType), Plugin(std::move(Plugin)) {}
|
|
ol_platform_backend_t BackendType;
|
|
|
|
/// Complete all pending work for this platform and perform any needed
|
|
/// cleanup.
|
|
///
|
|
/// After calling this function, no liboffload functions should be called with
|
|
/// this platform handle.
|
|
llvm::Error destroy();
|
|
|
|
/// Initialize the associated plugin and devices.
|
|
llvm::Error init();
|
|
|
|
/// Direct access to the plugin, may be uninitialized if accessed here.
|
|
std::unique_ptr<GenericPluginTy> Plugin;
|
|
|
|
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
|
|
};
|
|
|
|
// Handle type definitions. Ideally these would be 1:1 with the plugins, but
|
|
// we add some additional data here for now to avoid churn in the plugin
|
|
// interface.
|
|
struct ol_device_impl_t {
|
|
ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
|
|
ol_platform_impl_t &Platform, InfoTreeNode &&DevInfo)
|
|
: DeviceNum(DeviceNum), Device(Device), Platform(Platform),
|
|
Info(std::forward<InfoTreeNode>(DevInfo)) {}
|
|
|
|
~ol_device_impl_t() {
|
|
assert(!OutstandingQueues.size() &&
|
|
"Device object dropped with outstanding queues");
|
|
}
|
|
|
|
int DeviceNum;
|
|
GenericDeviceTy *Device;
|
|
ol_platform_impl_t &Platform;
|
|
InfoTreeNode Info;
|
|
|
|
llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
|
|
std::mutex OutstandingQueuesMutex;
|
|
|
|
/// If the device has any outstanding queues that are now complete, remove it
|
|
/// from the list and return it.
|
|
///
|
|
/// Queues may be added to the outstanding queue list by olDestroyQueue if
|
|
/// they are destroyed but not completed.
|
|
__tgt_async_info *getOutstandingQueue() {
|
|
// Not locking the `size()` access is fine here - In the worst case we
|
|
// either miss a queue that exists or loop through an empty array after
|
|
// taking the lock. Both are sub-optimal but not that bad.
|
|
if (OutstandingQueues.size()) {
|
|
std::lock_guard<std::mutex> Lock(OutstandingQueuesMutex);
|
|
|
|
// As queues are pulled and popped from this list, longer running queues
|
|
// naturally bubble to the start of the array. Hence looping backwards.
|
|
for (auto Q = OutstandingQueues.rbegin(); Q != OutstandingQueues.rend();
|
|
Q++) {
|
|
if (!Device->hasPendingWork(*Q)) {
|
|
auto OutstandingQueue = *Q;
|
|
*Q = OutstandingQueues.back();
|
|
OutstandingQueues.pop_back();
|
|
return OutstandingQueue;
|
|
}
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
/// Complete all pending work for this device and perform any needed cleanup.
|
|
///
|
|
/// After calling this function, no liboffload functions should be called with
|
|
/// this device handle.
|
|
llvm::Error destroy() {
|
|
llvm::Error Result = Plugin::success();
|
|
for (auto Q : OutstandingQueues)
|
|
if (auto Err = Device->synchronize(Q, /*Release=*/true))
|
|
Result = llvm::joinErrors(std::move(Result), std::move(Err));
|
|
OutstandingQueues.clear();
|
|
return Result;
|
|
}
|
|
};
|
|
|
|
/// Tear down every device of this platform and then deinitialize the plugin.
/// Failures do not stop the teardown; they are accumulated and returned
/// together.
llvm::Error ol_platform_impl_t::destroy() {
  llvm::Error Accumulated = Plugin::success();
  auto Record = [&Accumulated](llvm::Error Err) {
    if (Err)
      Accumulated = llvm::joinErrors(std::move(Accumulated), std::move(Err));
  };

  for (auto &Dev : Devices)
    Record(Dev->destroy());

  Record(Plugin->deinit());

  return Accumulated;
}
|
|
|
|
/// Initialize the plugin, then initialize each of its devices and create a
/// device handle for it. Stops at, and returns, the first error encountered.
/// A platform without a plugin is treated as trivially initialized.
llvm::Error ol_platform_impl_t::init() {
  if (!Plugin)
    return llvm::Error::success();

  if (llvm::Error Err = Plugin->init())
    return Err;

  const int NumDevices = Plugin->getNumDevices();
  for (int Id = 0; Id != NumDevices; ++Id) {
    if (llvm::Error Err = Plugin->initDevice(Id))
      return Err;

    GenericDeviceTy &PluginDevice = Plugin->getDevice(Id);
    llvm::Expected<InfoTreeNode> Info = PluginDevice.obtainInfo();
    if (!Info)
      return Info.takeError();

    Devices.emplace_back(std::make_unique<ol_device_impl_t>(
        Id, &PluginDevice, *this, std::move(*Info)));
  }

  return llvm::Error::success();
}
|
|
|
|
struct ol_queue_impl_t {
|
|
ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
|
|
: AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
|
|
__tgt_async_info *AsyncInfo;
|
|
ol_device_handle_t Device;
|
|
// A unique identifier for the queue
|
|
size_t Id;
|
|
static std::atomic<size_t> IdCounter;
|
|
};
|
|
std::atomic<size_t> ol_queue_impl_t::IdCounter(0);
|
|
|
|
struct ol_event_impl_t {
|
|
ol_event_impl_t(void *EventInfo, ol_device_handle_t Device,
|
|
ol_queue_handle_t Queue)
|
|
: EventInfo(EventInfo), Device(Device), QueueId(Queue->Id), Queue(Queue) {
|
|
}
|
|
// EventInfo may be null, in which case the event should be considered always
|
|
// complete
|
|
void *EventInfo;
|
|
ol_device_handle_t Device;
|
|
size_t QueueId;
|
|
// Events may outlive the queue - don't assume this is always valid.
|
|
// It is provided only to implement OL_EVENT_INFO_QUEUE. Use QueueId to check
|
|
// for queue equality instead.
|
|
ol_queue_handle_t Queue;
|
|
};
|
|
|
|
struct ol_program_impl_t {
|
|
ol_program_impl_t(plugin::DeviceImageTy *Image,
|
|
llvm::MemoryBufferRef DeviceImage)
|
|
: Image(Image), DeviceImage(DeviceImage) {}
|
|
plugin::DeviceImageTy *Image;
|
|
std::mutex SymbolListMutex;
|
|
llvm::MemoryBufferRef DeviceImage;
|
|
llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> KernelSymbols;
|
|
llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> GlobalSymbols;
|
|
};
|
|
|
|
struct ol_symbol_impl_t {
|
|
ol_symbol_impl_t(const char *Name, GenericKernelTy *Kernel)
|
|
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL), Name(Name) {}
|
|
ol_symbol_impl_t(const char *Name, GlobalTy &&Global)
|
|
: PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {}
|
|
std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
|
|
ol_symbol_kind_t Kind;
|
|
llvm::StringRef Name;
|
|
};
|
|
|
|
namespace llvm {
|
|
namespace offload {
|
|
|
|
struct AllocInfo {
|
|
ol_device_handle_t Device;
|
|
ol_alloc_type_t Type;
|
|
void *Start;
|
|
// One byte past the end
|
|
void *End;
|
|
};
|
|
|
|
// Global shared state for liboffload
|
|
struct OffloadContext;
|
|
// This pointer is non-null if and only if the context is valid and fully
|
|
// initialized
|
|
static std::atomic<OffloadContext *> OffloadContextVal;
|
|
std::mutex OffloadContextValMutex;
|
|
struct OffloadContext {
|
|
OffloadContext(OffloadContext &) = delete;
|
|
OffloadContext(OffloadContext &&) = delete;
|
|
OffloadContext &operator=(OffloadContext &) = delete;
|
|
OffloadContext &operator=(OffloadContext &&) = delete;
|
|
|
|
bool TracingEnabled = false;
|
|
bool ValidationEnabled = true;
|
|
DenseMap<void *, AllocInfo> AllocInfoMap{};
|
|
std::mutex AllocInfoMapMutex{};
|
|
// Partitioned list of memory base addresses. Each element in this list is a
|
|
// key in AllocInfoMap
|
|
SmallVector<void *> AllocBases{};
|
|
SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
|
|
size_t RefCount;
|
|
|
|
static OffloadContext &get() {
|
|
assert(OffloadContextVal);
|
|
return *OffloadContextVal;
|
|
}
|
|
};
|
|
|
|
// If the context is uninited, then we assume tracing is disabled
|
|
bool isTracingEnabled() {
|
|
return isOffloadInitialized() && OffloadContext::get().TracingEnabled;
|
|
}
|
|
bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; }
|
|
bool isOffloadInitialized() { return OffloadContextVal != nullptr; }
|
|
|
|
/// Generic handle destructor: frees the object behind \p Handle with `delete`
/// and reports success.
template <typename HandleT> Error olDestroy(HandleT Handle) {
  delete Handle;
  return Error::success();
}
|
|
|
|
/// Map a plugin target name to the backend enumerator it implements. Names
/// that are not recognized map to OL_PLATFORM_BACKEND_UNKNOWN.
constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
  if (Name == "amdgpu")
    return OL_PLATFORM_BACKEND_AMDGPU;
  if (Name == "cuda")
    return OL_PLATFORM_BACKEND_CUDA;
  if (Name == "host")
    return OL_PLATFORM_BACKEND_HOST;
  if (Name == "level_zero")
    return OL_PLATFORM_BACKEND_LEVEL_ZERO;
  return OL_PLATFORM_BACKEND_UNKNOWN;
}
|
|
|
|
// Every plugin exports this method to create an instance of the plugin type.
|
|
#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
|
|
#include "Shared/Targets.def"
|
|
|
|
/// Create a platform for every plugin target that was requested (or for all
/// targets when nothing specific was requested), initialize each of them, and
/// read the runtime options from the environment.
///
/// \param Context  context whose platform list and options are filled in.
/// \param InitArgs optional init arguments; may be null.
/// \returns the first platform initialization error, or success.
Error initPlugins(OffloadContext &Context, const ol_init_args_t *InitArgs) {
  SmallSet<ol_platform_backend_t, 0> Requested;
  if (InitArgs)
    for (uint32_t I = 0; I < InitArgs->NumPlatforms; ++I)
      Requested.insert(InitArgs->Platforms[I]);

  // Attempt to create an instance of each supported plugin, skipping
  // unrequested backends. The host plugin is always created.
#define PLUGIN_TARGET(Name)                                                    \
  do {                                                                         \
    auto Backend = pluginNameToBackend(#Name);                                 \
    if (Requested.empty() || Backend == OL_PLATFORM_BACKEND_HOST ||            \
        Requested.contains(Backend)) {                                         \
      Context.Platforms.emplace_back(std::make_unique<ol_platform_impl_t>(     \
          std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), Backend));  \
    }                                                                          \
  } while (false);
#include "Shared/Targets.def"

  // Eagerly initialize all of the plugins and devices. We need to make sure
  // that the platform is initialized at a consistent point to maintain the
  // expected teardown order in the vendor libraries.
  for (auto &Platform : Context.Platforms)
    if (Error Err = Platform->init())
      return Err;

  // Both options are keyed on the mere presence of the environment variable.
  Context.TracingEnabled = std::getenv("OFFLOAD_TRACE") != nullptr;
  Context.ValidationEnabled =
      std::getenv("OFFLOAD_DISABLE_VALIDATION") == nullptr;

  return Plugin::success();
}
|
|
|
|
/// Initialize liboffload, or take another reference on it if it is already
/// initialized.
///
/// \param InitArgs optional arguments; validated only on the first
///                 initialization.
/// \returns an error for malformed arguments, otherwise the result of plugin
///          initialization.
Error olInit_impl(const ol_init_args_t *InitArgs) {
  std::lock_guard<std::mutex> Lock(OffloadContextValMutex);

  // Already initialized: just count the additional user.
  if (isOffloadInitialized()) {
    ++OffloadContext::get().RefCount;
    return Plugin::success();
  }

  const bool SizeTooSmall =
      InitArgs && InitArgs->Size < sizeof(ol_init_args_t);
  if (SizeTooSmall)
    return createOffloadError(ErrorCode::INVALID_SIZE,
                              "ol_init_args_t Size field is too small");

  const bool MissingPlatforms =
      InitArgs && InitArgs->NumPlatforms > 0 && !InitArgs->Platforms;
  if (MissingPlatforms)
    return createOffloadError(ErrorCode::INVALID_NULL_POINTER,
                              "NumPlatforms > 0 but Platforms is null");

  // Use a temporary to ensure that entry points querying OffloadContextVal do
  // not get a partially initialized context
  auto *NewContext = new OffloadContext{};
  Error InitResult = initPlugins(*NewContext, InitArgs);

  // NOTE(review): the context is published and ref-counted even when
  // initPlugins reported an error — confirm that this is intended.
  OffloadContextVal.store(NewContext);
  ++NewContext->RefCount;

  return InitResult;
}
|
|
|
|
/// Drop one reference on liboffload. When the last reference goes away, the
/// context is unpublished, every initialized platform is destroyed and the
/// context is freed.
///
/// \returns the joined errors of all platform teardowns, or success.
Error olShutDown_impl() {
  std::lock_guard<std::mutex> Lock(OffloadContextValMutex);

  OffloadContext &Live = OffloadContext::get();
  Live.RefCount -= 1;
  if (Live.RefCount != 0)
    return Error::success();

  // Unpublish first so that no entry point can observe a context that is
  // being torn down; the owner frees it when this function returns.
  std::unique_ptr<OffloadContext> OldContext(
      OffloadContextVal.exchange(nullptr));

  Error Result = Error::success();
  for (auto &Platform : OldContext->Platforms) {
    // Host plugin is nullptr and has no deinit
    const bool NeedsTeardown =
        Platform->Plugin && Platform->Plugin->is_initialized();
    if (!NeedsTeardown)
      continue;

    if (auto Res = Platform->destroy())
      Result = joinErrors(std::move(Result), std::move(Res));
  }

  return Result;
}
|
|
|
|
/// Shared implementation of the platform info queries.
///
/// \param Platform    platform being queried.
/// \param PropName    property to report.
/// \param PropSize    size of the buffer behind \p PropValue.
/// \param PropValue   destination for the value; may be null.
/// \param PropSizeRet destination for the value size; may be null.
/// \returns INVALID_ENUMERATION for unknown properties, otherwise the result
///          of writing the value.
Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
                                  ol_platform_info_t PropName, size_t PropSize,
                                  void *PropValue, size_t *PropSizeRet) {
  InfoWriter Writer(PropSize, PropValue, PropSizeRet);

  // Note that the plugin is potentially uninitialized here. It will need to be
  // initialized once info is added that requires it to be initialized.
  switch (PropName) {
  case OL_PLATFORM_INFO_NAME:
    return Writer.writeString(Platform->Plugin->getName());

  case OL_PLATFORM_INFO_VENDOR_NAME:
    // TODO: Implement this
    return Writer.writeString("Unknown platform vendor");

  case OL_PLATFORM_INFO_VERSION:
    return Writer.writeString(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
                                      OL_VERSION_MINOR, OL_VERSION_PATCH)
                                  .str());

  case OL_PLATFORM_INFO_BACKEND:
    return Writer.write<ol_platform_backend_t>(Platform->BackendType);

  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "getPlatformInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}
|
|
|
|
/// Write the value of a platform property into the caller's buffer. No size
/// is reported back.
Error olGetPlatformInfo_impl(ol_platform_handle_t Platform,
                             ol_platform_info_t PropName, size_t PropSize,
                             void *PropValue) {
  size_t *const NoSizeOut = nullptr;
  return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue,
                                     NoSizeOut);
}
|
|
|
|
/// Report only the size of a platform property; no value buffer is supplied.
Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
                                 ol_platform_info_t PropName,
                                 size_t *PropSizeRet) {
  return olGetPlatformInfoImplDetail(Platform, PropName, /*PropSize=*/0,
                                     /*PropValue=*/nullptr, PropSizeRet);
}
|
|
|
|
/// Register \p Callback with the RPC server of the platform's plugin. Always
/// reports success.
Error olPlatformRegisterRPCCallback_impl(ol_platform_handle_t Platform,
                                         ol_platform_rpc_cb_t Callback) {
  auto &Server = Platform->Plugin->getRPCServer();
  Server.registerCallback(Callback);
  return Error::success();
}
|
|
|
|
/// Shared implementation of the device info queries.
///
/// Properties that the plugin interface does not provide are answered
/// directly; everything else is looked up in the device's info tree and
/// type-checked before it is written out.
///
/// \param Device      device being queried.
/// \param PropName    property to report.
/// \param PropSize    size of the buffer behind \p PropValue.
/// \param PropValue   destination for the value; may be null.
/// \param PropSizeRet destination for the value size; may be null.
/// \returns INVALID_ENUMERATION for out-of-range properties, UNIMPLEMENTED
///          when the plugin supplied no entry, BACKEND_FAILURE when the entry
///          has an unexpected type or range, otherwise the result of writing
///          the value.
Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
                                ol_device_info_t PropName, size_t PropSize,
                                void *PropValue, size_t *PropSizeRet) {
  InfoWriter Info(PropSize, PropValue, PropSizeRet);

  // Build an error whose message is prefixed with the property name. The
  // caller-supplied code is forwarded as-is so BACKEND_FAILURE conditions are
  // not reported as UNIMPLEMENTED.
  auto makeError = [&](ErrorCode Code, StringRef Err) {
    std::string ErrBuffer;
    raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
    return Plugin::error(Code, ErrBuffer.c_str());
  };
  bool IsHost = Device->Platform.BackendType == OL_PLATFORM_BACKEND_HOST;

  // These are not implemented by the plugin interface
  switch (PropName) {
  case OL_DEVICE_INFO_PLATFORM:
    return Info.write<void *>(&Device->Platform);

  case OL_DEVICE_INFO_TYPE:
    if (IsHost)
      return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
    else
      return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);

  case OL_DEVICE_INFO_SINGLE_FP_CONFIG:
  case OL_DEVICE_INFO_DOUBLE_FP_CONFIG: {
    ol_device_fp_capability_flags_t Flags{0};
    Flags |= OL_DEVICE_FP_CAPABILITY_FLAG_CORRECTLY_ROUNDED_DIVIDE_SQRT |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_NEAREST |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_ZERO |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_INF |
             OL_DEVICE_FP_CAPABILITY_FLAG_INF_NAN |
             OL_DEVICE_FP_CAPABILITY_FLAG_DENORM |
             OL_DEVICE_FP_CAPABILITY_FLAG_FMA;
    return Info.write(Flags);
  }

  case OL_DEVICE_INFO_HALF_FP_CONFIG:
    return Info.write<ol_device_fp_capability_flags_t>(0);

  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_CHAR:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_SHORT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_INT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_FLOAT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_DOUBLE:
    return Info.write<uint32_t>(1);

  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_HALF:
    return Info.write<uint32_t>(0);

  // None of the existing plugins specify a limit on a single allocation,
  // so return the global memory size instead
  case OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
    [[fallthrough]];
  // AMD doesn't provide the global memory size (trivially) with the device info
  // struct, so use the plugin interface
  case OL_DEVICE_INFO_GLOBAL_MEM_SIZE: {
    uint64_t Mem;
    if (auto Err = Device->Device->getDeviceMemorySize(Mem))
      return Err;
    return Info.write<uint64_t>(Mem);
  } break;

  default:
    break;
  }

  if (PropName >= OL_DEVICE_INFO_LAST)
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "getDeviceInfo enum '%i' is invalid", PropName);

  auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName));
  if (!EntryOpt)
    return makeError(ErrorCode::UNIMPLEMENTED,
                     "plugin did not provide a response for this information");
  auto Entry = *EntryOpt;

  // Retrieve properties from the plugin interface
  switch (PropName) {
  case OL_DEVICE_INFO_NAME:
  case OL_DEVICE_INFO_PRODUCT_NAME:
  case OL_DEVICE_INFO_UID:
  case OL_DEVICE_INFO_VENDOR:
  case OL_DEVICE_INFO_DRIVER_VERSION: {
    // String values
    if (!std::holds_alternative<std::string>(Entry->Value))
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned incorrect type");
    return Info.writeString(std::get<std::string>(Entry->Value).c_str());
  }

  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
  case OL_DEVICE_INFO_MAX_WORK_SIZE:
  case OL_DEVICE_INFO_VENDOR_ID:
  case OL_DEVICE_INFO_NUM_COMPUTE_UNITS:
  case OL_DEVICE_INFO_ADDRESS_BITS:
  case OL_DEVICE_INFO_MAX_CLOCK_FREQUENCY:
  case OL_DEVICE_INFO_MEMORY_CLOCK_RATE: {
    // Uint32 values: stored as uint64_t in the tree, so range-check before
    // narrowing.
    if (!std::holds_alternative<uint64_t>(Entry->Value))
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned incorrect type");
    auto Value = std::get<uint64_t>(Entry->Value);
    if (Value > std::numeric_limits<uint32_t>::max())
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned out of range device info");
    return Info.write(static_cast<uint32_t>(Value));
  }

  case OL_DEVICE_INFO_WORK_GROUP_LOCAL_MEM_SIZE: {
    if (!std::holds_alternative<uint64_t>(Entry->Value))
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned incorrect type");
    return Info.write(std::get<uint64_t>(Entry->Value));
  }

  case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION:
  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION: {
    // {x, y, z} triples
    ol_dimensions_t Out{0, 0, 0};

    // Fetch one child entry of the triple into Dest.
    auto getField = [&](StringRef Name, uint32_t &Dest) {
      if (auto F = Entry->get(Name)) {
        if (!std::holds_alternative<uint64_t>((*F)->Value))
          return makeError(
              ErrorCode::BACKEND_FAILURE,
              "plugin returned incorrect type for dimensions element");
        Dest = std::get<uint64_t>((*F)->Value);
      } else
        return makeError(ErrorCode::BACKEND_FAILURE,
                         "plugin didn't provide all values for dimensions");
      return Plugin::success();
    };

    if (auto Res = getField("x", Out.x))
      return Res;
    if (auto Res = getField("y", Out.y))
      return Res;
    if (auto Res = getField("z", Out.z))
      return Res;

    return Info.write(Out);
  }

  default:
    llvm_unreachable("Unimplemented device info");
  }
}
|
|
|
|
/// Write the value of a device property into the caller's buffer. No size is
/// reported back.
Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
                           size_t PropSize, void *PropValue) {
  size_t *const NoSizeOut = nullptr;
  return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
                                   NoSizeOut);
}
|
|
|
|
/// Report only the size of a device property; no value buffer is supplied.
Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
                               ol_device_info_t PropName, size_t *PropSizeRet) {
  return olGetDeviceInfoImplDetail(Device, PropName, /*PropSize=*/0,
                                   /*PropValue=*/nullptr, PropSizeRet);
}
|
|
|
|
/// Invoke \p Callback once per device of every platform, passing \p UserData
/// through. Iteration stops as soon as the callback returns false. Always
/// reports success.
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
  for (auto &Platform : OffloadContext::get().Platforms)
    for (auto &Dev : Platform->Devices) {
      const bool KeepGoing = Callback(Dev.get(), UserData);
      if (!KeepGoing)
        return Error::success();
    }

  return Error::success();
}
|
|
|
|
/// Translate a liboffload allocation type into the plugin's allocation kind.
/// Managed allocations, and any value not listed here, map to the shared kind.
TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {
  if (Type == OL_ALLOC_TYPE_DEVICE)
    return TARGET_ALLOC_DEVICE;
  if (Type == OL_ALLOC_TYPE_HOST)
    return TARGET_ALLOC_HOST;
  // OL_ALLOC_TYPE_MANAGED and everything else
  return TARGET_ALLOC_SHARED;
}
|
|
|
|
constexpr size_t MAX_ALLOC_TRIES = 50;
|
|
Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
|
|
size_t Size, void **AllocationOut) {
|
|
SmallVector<void *> Rejects;
|
|
|
|
// Repeat the allocation up to a certain amount of times. If it happens to
|
|
// already be allocated (e.g. by a device from another vendor) throw it away
|
|
// and try again.
|
|
for (size_t Count = 0; Count < MAX_ALLOC_TRIES; Count++) {
|
|
auto NewAlloc = Device->Device->dataAlloc(Size, nullptr,
|
|
convertOlToPluginAllocTy(Type));
|
|
if (!NewAlloc)
|
|
return NewAlloc.takeError();
|
|
|
|
void *NewEnd = &static_cast<char *>(*NewAlloc)[Size];
|
|
auto &AllocBases = OffloadContext::get().AllocBases;
|
|
auto &AllocInfoMap = OffloadContext::get().AllocInfoMap;
|
|
{
|
|
std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
|
|
|
|
// Check that this memory region doesn't overlap another one
|
|
// That is, the start of this allocation needs to be after another
|
|
// allocation's end point, and the end of this allocation needs to be
|
|
// before the next one's start.
|
|
// `Gap` is the first alloc who ends after the new alloc's start point.
|
|
auto Gap =
|
|
std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc,
|
|
[&](const void *Iter, const void *Val) {
|
|
return AllocInfoMap.at(Iter).End <= Val;
|
|
});
|
|
if (Gap == AllocBases.end() || NewEnd <= AllocInfoMap.at(*Gap).Start) {
|
|
// Success, no conflict
|
|
AllocInfoMap.insert_or_assign(
|
|
*NewAlloc, AllocInfo{Device, Type, *NewAlloc, NewEnd});
|
|
AllocBases.insert(
|
|
std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc),
|
|
*NewAlloc);
|
|
*AllocationOut = *NewAlloc;
|
|
|
|
for (void *R : Rejects)
|
|
if (auto Err =
|
|
Device->Device->dataDelete(R, convertOlToPluginAllocTy(Type)))
|
|
return Err;
|
|
return Error::success();
|
|
}
|
|
|
|
// To avoid the next attempt allocating the same memory we just freed, we
|
|
// hold onto it until we complete the allocation
|
|
Rejects.push_back(*NewAlloc);
|
|
}
|
|
}
|
|
|
|
// We've tried multiple times, and can't allocate a non-overlapping region.
|
|
return createOffloadError(ErrorCode::BACKEND_FAILURE,
|
|
"failed to allocate non-overlapping memory");
|
|
}
|
|
|
|
/// Stop tracking the allocation whose base address is \p Address and hand the
/// memory back to the plugin that produced it.
///
/// \returns INVALID_ARGUMENT when \p Address is not a tracked base address,
///          otherwise the result of the plugin's delete.
Error olMemFree_impl(void *Address) {
  ol_device_handle_t Owner;
  ol_alloc_type_t Kind;
  {
    OffloadContext &Context = OffloadContext::get();
    std::lock_guard<std::mutex> Lock(Context.AllocInfoMapMutex);

    auto Record = Context.AllocInfoMap.find(Address);
    if (Record == Context.AllocInfoMap.end())
      return createOffloadError(ErrorCode::INVALID_ARGUMENT,
                                "address is not a known allocation");

    // Capture what the plugin call needs before the record disappears.
    Owner = Record->second.Device;
    Kind = Record->second.Type;
    Context.AllocInfoMap.erase(Record);

    auto &Bases = Context.AllocBases;
    Bases.erase(std::lower_bound(Bases.begin(), Bases.end(), Address));
  }

  // The plugin is called only after the lock scope above has ended.
  return Owner->Device->dataDelete(Address, convertOlToPluginAllocTy(Kind));
}
|
|
|
|
/// Shared implementation of the memory info queries.
///
/// \p Ptr may be the base address of a tracked allocation or any address
/// inside one.
///
/// \returns NOT_FOUND when \p Ptr lies in no tracked allocation,
///          INVALID_ENUMERATION for unknown properties, otherwise the result
///          of writing the value.
Error olGetMemInfoImplDetail(const void *Ptr, ol_mem_info_t PropName,
                             size_t PropSize, void *PropValue,
                             size_t *PropSizeRet) {
  InfoWriter Writer(PropSize, PropValue, PropSizeRet);
  std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);

  auto &Bases = OffloadContext::get().AllocBases;
  auto &Records = OffloadContext::get().AllocInfoMap;

  const AllocInfo *Alloc = nullptr;
  if (Records.contains(Ptr)) {
    // Fast case, we have been given the base pointer directly
    Alloc = &Records.at(Ptr);
  } else {
    // Slower case, we need to look up the base pointer first
    // Find the first memory allocation whose end is after the target pointer,
    // and then check to see if it is in range
    auto EndsAtOrBefore = [&](const void *Base, const void *Target) {
      return Records.at(Base).End <= Target;
    };
    auto Loc = std::lower_bound(Bases.begin(), Bases.end(), Ptr, EndsAtOrBefore);

    const AllocInfo *Candidate =
        Loc == Bases.end() ? nullptr : &Records.at(*Loc);
    if (!Candidate || Ptr < Candidate->Start)
      return Plugin::error(ErrorCode::NOT_FOUND,
                           "allocated memory information not found");
    Alloc = Candidate;
  }

  switch (PropName) {
  case OL_MEM_INFO_DEVICE:
    return Writer.write<ol_device_handle_t>(Alloc->Device);
  case OL_MEM_INFO_BASE:
    return Writer.write<void *>(Alloc->Start);
  case OL_MEM_INFO_SIZE:
    return Writer.write<size_t>(static_cast<char *>(Alloc->End) -
                                static_cast<char *>(Alloc->Start));
  case OL_MEM_INFO_TYPE:
    return Writer.write<ol_alloc_type_t>(Alloc->Type);
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "olGetMemInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}
|
|
|
|
/// Write the value of a memory property into the caller's buffer. No size is
/// reported back.
Error olGetMemInfo_impl(const void *Ptr, ol_mem_info_t PropName,
                        size_t PropSize, void *PropValue) {
  size_t *const NoSizeOut = nullptr;
  return olGetMemInfoImplDetail(Ptr, PropName, PropSize, PropValue, NoSizeOut);
}
|
|
|
|
/// Report only the size of a memory property; no value buffer is supplied.
Error olGetMemInfoSize_impl(const void *Ptr, ol_mem_info_t PropName,
                            size_t *PropSizeRet) {
  return olGetMemInfoImplDetail(Ptr, PropName, /*PropSize=*/0,
                                /*PropValue=*/nullptr, PropSizeRet);
}
|
|
|
|
/// Create a queue on \p Device. A finished queue left over from an earlier
/// olDestroyQueue is reused when one is available; otherwise the plugin is
/// asked for a fresh one.
///
/// \param Device device to create the queue on.
/// \param Queue  receives the new queue handle on success.
Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
  auto NewQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);
  auto *PluginDevice = Device->Device;

  if (__tgt_async_info *Recycled = Device->getOutstandingQueue()) {
    // The queue is empty, but we still need to sync it to release any temporary
    // memory allocations or do other cleanup.
    if (auto Err = PluginDevice->synchronize(Recycled, /*Release=*/false))
      return Err;
    NewQueue->AsyncInfo = Recycled;
  } else {
    if (auto Err = PluginDevice->initAsyncInfo(&(NewQueue->AsyncInfo)))
      return Err;
  }

  *Queue = NewQueue.release();
  return Error::success();
}
|
|
|
|
/// Destroy a queue handle. An idle queue is synchronized and released
/// immediately; a busy one is parked on the device's outstanding-queue list so
/// it can be dealt with later.
Error olDestroyQueue_impl(ol_queue_handle_t Queue) {
  auto *Owner = Queue->Device;

  // This is safe; as soon as olDestroyQueue is called it is not possible to add
  // any more work to the queue, so if it's finished now it will remain finished
  // forever.
  auto Pending = Owner->Device->hasPendingWork(Queue->AsyncInfo);
  if (auto Err = Pending.takeError())
    return Err;

  if (*Pending) {
    // The queue still has outstanding work. Store it so we can check it later.
    std::lock_guard<std::mutex> Lock(Owner->OutstandingQueuesMutex);
    Owner->OutstandingQueues.push_back(Queue->AsyncInfo);
  } else if (auto Err = Owner->Device->synchronize(Queue->AsyncInfo,
                                                   /*Release=*/true)) {
    // The queue was complete, but syncing it back into the pool failed.
    return Err;
  }

  return olDestroy(Queue);
}
|
|
|
|
/// Block until all work submitted to \p Queue has finished.
Error olSyncQueue_impl(ol_queue_handle_t Queue) {
  // Host plugin doesn't have a queue set so it's not safe to call synchronize
  // on it, but we have nothing to synchronize in that situation anyway.
  if (!Queue->AsyncInfo->Queue)
    return Error::success();

  // We don't need to release the queue and we would like the ability for
  // other offload threads to submit work concurrently, so pass "false" here
  // so we don't release the underlying queue object.
  return Queue->Device->Device->synchronize(Queue->AsyncInfo, false);
}
|
|
|
|
/// Make \p Queue wait for each of the given events.
///
/// \returns INVALID_NULL_HANDLE on the first null entry, or the first error
///          reported by the plugin.
Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events,
                        size_t NumEvents) {
  auto *PluginDevice = Queue->Device->Device;

  for (size_t Index = 0; Index != NumEvents; ++Index) {
    auto *Current = Events[Index];
    if (!Current)
      return Plugin::error(ErrorCode::INVALID_NULL_HANDLE,
                           "olWaitEvents asked to wait on a NULL event");

    // Do nothing if the event is for this queue or the event is always complete
    const bool NothingToWaitFor =
        Current->QueueId == Queue->Id || !Current->EventInfo;
    if (NothingToWaitFor)
      continue;

    if (auto Err = PluginDevice->waitEvent(Current->EventInfo, Queue->AsyncInfo))
      return Err;
  }

  return Error::success();
}
|
|
|
|
/// Shared implementation of the queue info queries.
///
/// \returns INVALID_ENUMERATION for unknown properties, a plugin error if the
///          pending-work query fails, otherwise the result of writing the
///          value.
Error olGetQueueInfoImplDetail(ol_queue_handle_t Queue,
                               ol_queue_info_t PropName, size_t PropSize,
                               void *PropValue, size_t *PropSizeRet) {
  InfoWriter Writer(PropSize, PropValue, PropSizeRet);

  switch (PropName) {
  case OL_QUEUE_INFO_DEVICE:
    return Writer.write<ol_device_handle_t>(Queue->Device);

  case OL_QUEUE_INFO_EMPTY: {
    // A queue is empty exactly when the plugin reports no pending work.
    auto Pending = Queue->Device->Device->hasPendingWork(Queue->AsyncInfo);
    if (!Pending)
      return Pending.takeError();
    const bool IsEmpty = !*Pending;
    return Writer.write<bool>(IsEmpty);
  }

  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "olGetQueueInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}
|
|
|
|
/// Write the value of a queue property into the caller's buffer. No size is
/// reported back.
Error olGetQueueInfo_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName,
                          size_t PropSize, void *PropValue) {
  size_t *const NoSizeOut = nullptr;
  return olGetQueueInfoImplDetail(Queue, PropName, PropSize, PropValue,
                                  NoSizeOut);
}
|
|
|
|
/// Report only the size of a queue property; no value buffer is supplied.
Error olGetQueueInfoSize_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName,
                              size_t *PropSizeRet) {
  return olGetQueueInfoImplDetail(Queue, PropName, /*PropSize=*/0,
                                  /*PropValue=*/nullptr, PropSizeRet);
}
|
|
|
|
/// Block until \p Event has completed.
Error olSyncEvent_impl(ol_event_handle_t Event) {
  // No event info means that this event was complete on creation
  void *PluginEvent = Event->EventInfo;
  if (!PluginEvent)
    return Plugin::success();

  return Event->Device->Device->syncEvent(PluginEvent);
}
|
|
|
|
/// Destroy an event handle, releasing the plugin event first when there is
/// one. If the plugin fails to release its event, the handle is left intact
/// and the error is returned.
Error olDestroyEvent_impl(ol_event_handle_t Event) {
  if (void *PluginEvent = Event->EventInfo) {
    auto Res = Event->Device->Device->destroyEvent(PluginEvent);
    if (Res)
      return Res;
  }

  return olDestroy(Event);
}
|
|
|
|
/// Shared implementation of the event info queries.
///
/// \returns INVALID_ENUMERATION for unknown properties, a plugin error if the
///          completion query fails, otherwise the result of writing the value.
Error olGetEventInfoImplDetail(ol_event_handle_t Event,
                               ol_event_info_t PropName, size_t PropSize,
                               void *PropValue, size_t *PropSizeRet) {
  InfoWriter Writer(PropSize, PropValue, PropSizeRet);
  auto SourceQueue = Event->Queue;

  switch (PropName) {
  case OL_EVENT_INFO_QUEUE:
    return Writer.write<ol_queue_handle_t>(SourceQueue);

  case OL_EVENT_INFO_IS_COMPLETE: {
    // No event info means that this event was complete on creation
    if (!Event->EventInfo)
      return Writer.write<bool>(true);

    // NOTE(review): this dereferences the event's queue handle; confirm the
    // queue is guaranteed to still be alive at this point.
    auto Complete = SourceQueue->Device->Device->isEventComplete(
        Event->EventInfo, SourceQueue->AsyncInfo);
    if (!Complete)
      return Complete.takeError();
    return Writer.write<bool>(*Complete);
  }

  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "olGetEventInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}
|
|
|
|
/// Write the value of an event property into the caller's buffer. No size is
/// reported back.
Error olGetEventInfo_impl(ol_event_handle_t Event, ol_event_info_t PropName,
                          size_t PropSize, void *PropValue) {
  size_t *const NoSizeOut = nullptr;
  return olGetEventInfoImplDetail(Event, PropName, PropSize, PropValue,
                                  NoSizeOut);
}
|
|
|
|
/// Report only the size of an event property; no value buffer is supplied.
Error olGetEventInfoSize_impl(ol_event_handle_t Event, ol_event_info_t PropName,
                              size_t *PropSizeRet) {
  return olGetEventInfoImplDetail(Event, PropName, /*PropSize=*/0,
                                  /*PropValue=*/nullptr, PropSizeRet);
}
|
|
|
|
/// Create an event that marks the current tail of \p Queue.
///
/// When the queue has no pending work, no plugin event is recorded and the
/// returned event counts as always complete (its EventInfo stays null).
///
/// The handle is held by a smart pointer until everything has succeeded, so
/// a plugin failure neither leaks the event object nor leaves a handle in
/// \p EventOut.
///
/// \param Queue    queue to record the event on.
/// \param EventOut receives the new event handle; written only on success.
/// \returns the first plugin error encountered, or success.
Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
  auto *PluginDevice = Queue->Device->Device;

  auto Pending = PluginDevice->hasPendingWork(Queue->AsyncInfo);
  if (auto Err = Pending.takeError())
    return Err;

  auto NewEvent =
      std::make_unique<ol_event_impl_t>(nullptr, Queue->Device, Queue);

  // Only a queue with pending work needs a plugin event. Otherwise the queue
  // is empty and the event is considered always complete.
  if (*Pending) {
    if (auto Err = PluginDevice->createEvent(&NewEvent->EventInfo))
      return Err;

    if (auto Err =
            PluginDevice->recordEvent(NewEvent->EventInfo, Queue->AsyncInfo)) {
      // The plugin event was created but could not be recorded: release it
      // again and report both failures, if the release fails too.
      return joinErrors(std::move(Err),
                        PluginDevice->destroyEvent(NewEvent->EventInfo));
    }
  }

  *EventOut = NewEvent.release();
  return Plugin::success();
}
|
|
|
|
/// Copy \p Size bytes from \p SrcPtr (on \p SrcDevice) to \p DstPtr (on
/// \p DstDevice).
///
/// The transfer strategy is chosen from the two devices:
///  - host -> host: plain std::memcpy; only allowed when \p Queue is null,
///    otherwise INVALID_ARGUMENT is returned.
///  - device -> host: dataRetrieve() on the source device.
///  - host -> device: dataSubmit() on the destination device.
///  - device -> device, same plugin and isDataExchangable(): dataExchange().
///  - any other device -> device pair: the queue (if any) is synchronised and
///    the data is staged through a temporary host buffer.
///
/// If no queue is given the memcpy will be synchronous.
///
/// \returns success, INVALID_ARGUMENT for a queued host-to-host copy,
/// OUT_OF_RESOURCES if the staging buffer cannot be allocated, or the error
/// reported by the plugin / queue synchronisation.
Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
                    ol_device_handle_t DstDevice, const void *SrcPtr,
                    ol_device_handle_t SrcDevice, size_t Size) {
  bool IsDstHost = DstDevice->Platform.BackendType == OL_PLATFORM_BACKEND_HOST;
  bool IsSrcHost = SrcDevice->Platform.BackendType == OL_PLATFORM_BACKEND_HOST;

  if (IsDstHost && IsSrcHost) {
    if (Queue)
      return createOffloadError(
          ErrorCode::INVALID_ARGUMENT,
          "one of DstDevice and SrcDevice must be a non-host device if "
          "queue is specified");

    std::memcpy(DstPtr, SrcPtr, Size);
    return Error::success();
  }

  // If no queue is given the memcpy will be synchronous
  auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr;

  if (IsDstHost)
    return SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl);

  if (IsSrcHost)
    return DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl);

  if (SrcDevice->Platform.Plugin == DstDevice->Platform.Plugin &&
      SrcDevice->Platform.Plugin->isDataExchangable(
          SrcDevice->Device->getDeviceId(), DstDevice->Device->getDeviceId()))
    return SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device, DstPtr,
                                           Size, QueueImpl);

  // No direct path between the two devices: drain the queue first so the
  // staged copy below observes all previously enqueued work, then bounce the
  // data through host memory.
  if (Queue)
    if (auto Res = olSyncQueue_impl(Queue))
      return Res;

  // malloc(0) is allowed to return nullptr, which must not be mistaken for an
  // allocation failure; always request at least one byte.
  std::unique_ptr<void, decltype(&free)> Staging(malloc(Size ? Size : 1),
                                                 &free);
  if (!Staging)
    return createOffloadError(ErrorCode::OUT_OF_RESOURCES,
                              "Couldn't allocate a buffer for transfer");

  // Both halves run without a queue, so the buffer may be released as soon as
  // the submit returns.
  if (auto Res =
          SrcDevice->Device->dataRetrieve(Staging.get(), SrcPtr, Size, nullptr))
    return Res;

  return DstDevice->Device->dataSubmit(DstPtr, Staging.get(), Size, nullptr);
}
|
|
|
|
/// Enqueue a pattern fill on \p Queue by forwarding to dataFill() on the
/// queue's device.
///
/// \p Ptr, \p PatternPtr, \p PatternSize and \p FillSize are passed through
/// untouched together with the queue's async info; the plugin's result is
/// returned as-is.
Error olMemFill_impl(ol_queue_handle_t Queue, void *Ptr, size_t PatternSize,
                     const void *PatternPtr, size_t FillSize) {
  auto *DeviceImpl = Queue->Device->Device;
  return DeviceImpl->dataFill(Ptr, PatternPtr, PatternSize, FillSize,
                              Queue->AsyncInfo);
}
|
|
|
|
/// Load the binary in [\p ProgData, \p ProgData + \p ProgDataSize) onto
/// \p Device and wrap the resulting image in a new program handle.
///
/// The bytes are handed to the device's loadBinary(); on failure its error is
/// returned and \p Program is left untouched. On success \p Program receives a
/// freshly allocated ol_program_impl_t built from the image and the image's
/// memory buffer.
Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
                           size_t ProgDataSize, ol_program_handle_t *Program) {
  const char *Bytes = reinterpret_cast<const char *>(ProgData);
  StringRef Buffer(Bytes, ProgDataSize);

  Expected<plugin::DeviceImageTy *> ImageOrErr =
      Device->Device->loadBinary(Device->Device->Plugin, Buffer);
  if (auto Err = ImageOrErr.takeError())
    return Err;

  plugin::DeviceImageTy *Image = *ImageOrErr;
  assert(Image && "loadBinary returned nullptr");

  *Program = new ol_program_impl_t(Image, Image->getMemoryBuffer());
  return Error::success();
}
|
|
|
|
/// Report through \p IsValid whether the binary in
/// [\p ProgData, \p ProgData + \p ProgDataSize) is compatible with \p Device.
///
/// A handle without an underlying plugin device is never compatible.
/// Otherwise the answer comes from the plugin's isDeviceCompatible() for this
/// device id. The function itself always returns success.
Error olIsValidBinary_impl(ol_device_handle_t Device, const void *ProgData,
                           size_t ProgDataSize, bool *IsValid) {
  StringRef Buffer(reinterpret_cast<const char *>(ProgData), ProgDataSize);

  if (Device->Device)
    *IsValid = Device->Device->Plugin.isDeviceCompatible(
        Device->Device->getDeviceId(), Buffer);
  else
    *IsValid = false;

  return Error::success();
}
|
|
|
|
/// Unload the image behind \p Program from its device and destroy the handle.
///
/// The image is first passed to the device's unloadBinary(); if that fails the
/// error is returned and neither the device's image list nor the handle is
/// touched. Afterwards the image is dropped from the device's LoadedImages
/// list (if present) and the handle is released via olDestroy().
Error olDestroyProgram_impl(ol_program_handle_t Program) {
  auto &Device = Program->Image->getDevice();
  if (auto Err = Device.unloadBinary(Program->Image))
    return Err;

  // Only erase when the image is actually listed: erasing the end iterator is
  // undefined behaviour.
  auto &LoadedImages = Device.LoadedImages;
  auto ImageIt =
      std::find(LoadedImages.begin(), LoadedImages.end(), Program->Image);
  if (ImageIt != LoadedImages.end())
    LoadedImages.erase(ImageIt);

  return olDestroy(Program);
}
|
|
|
|
/// Ask \p Kernel for its maximum group size on \p Device given
/// \p DynamicMemSize bytes of dynamic memory, and store it in \p GroupSize.
///
/// \returns SYMBOL_KIND if \p Kernel is not a kernel symbol, the error from
/// the kernel's maxGroupSize() query if that fails, and success otherwise.
Error olCalculateOptimalOccupancy_impl(ol_device_handle_t Device,
                                       ol_symbol_handle_t Kernel,
                                       size_t DynamicMemSize,
                                       size_t *GroupSize) {
  const bool IsKernel = Kernel->Kind == OL_SYMBOL_KIND_KERNEL;
  if (!IsKernel)
    return createOffloadError(ErrorCode::SYMBOL_KIND,
                              "provided symbol is not a kernel");

  auto *KernelImpl = std::get<GenericKernelTy *>(Kernel->PluginImpl);
  auto MaxSizeOrErr = KernelImpl->maxGroupSize(*Device->Device, DynamicMemSize);
  if (!MaxSizeOrErr)
    return MaxSizeOrErr.takeError();

  *GroupSize = *MaxSizeOrErr;
  return Error::success();
}
|
|
|
|
/// Launch \p Kernel on \p Device, optionally on \p Queue.
///
/// \p ArgumentsData / \p ArgumentsSize describe the argument blob, which is
/// handed to the plugin as-is; \p LaunchSizeArgs supplies the number of
/// groups, the group size and the dynamic shared memory size.
///
/// \returns INVALID_DEVICE if a queue is given that belongs to another device,
/// SYMBOL_KIND if \p Kernel is not a kernel symbol, or the result of the
/// plugin launch.
Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
                          ol_symbol_handle_t Kernel, const void *ArgumentsData,
                          size_t ArgumentsSize,
                          const ol_kernel_launch_size_args_t *LaunchSizeArgs) {
  auto *DeviceImpl = Device->Device;

  // Validate everything before the async-info wrapper below comes to life.
  const bool QueueOnOtherDevice = Queue && Device != Queue->Device;
  if (QueueOnOtherDevice)
    return createOffloadError(
        ErrorCode::INVALID_DEVICE,
        "device specified does not match the device of the given queue");

  if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL)
    return createOffloadError(ErrorCode::SYMBOL_KIND,
                              "provided symbol is not a kernel");

  auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
  AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl);

  // Translate the launch geometry into the plugin's argument structure.
  const auto &NumGroups = LaunchSizeArgs->NumGroups;
  const auto &GroupSize = LaunchSizeArgs->GroupSize;
  KernelArgsTy LaunchArgs{};
  LaunchArgs.NumTeams[0] = NumGroups.x;
  LaunchArgs.NumTeams[1] = NumGroups.y;
  LaunchArgs.NumTeams[2] = NumGroups.z;
  LaunchArgs.ThreadLimit[0] = GroupSize.x;
  LaunchArgs.ThreadLimit[1] = GroupSize.y;
  LaunchArgs.ThreadLimit[2] = GroupSize.z;
  LaunchArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory;

  // The argument blob travels as a data/size pair smuggled through ArgPtrs.
  KernelLaunchParamsTy Params;
  Params.Data = const_cast<void *>(ArgumentsData);
  Params.Size = ArgumentsSize;
  LaunchArgs.ArgPtrs = reinterpret_cast<void **>(&Params);
  // Don't do anything with pointer indirection; use arg data as-is
  LaunchArgs.Flags.IsCUDA = true;

  auto *KernelImpl = std::get<GenericKernelTy *>(Kernel->PluginImpl);
  auto LaunchErr = KernelImpl->launch(*DeviceImpl, LaunchArgs.ArgPtrs, nullptr,
                                      LaunchArgs, AsyncInfoWrapper);

  // The wrapper is finalized with the launch result before that result is
  // handed back to the caller.
  AsyncInfoWrapper.finalize(LaunchErr);
  return LaunchErr;
}
|
|
|
|
/// Look up the symbol called \p Name of kind \p Kind in \p Program and return
/// its handle through \p Symbol.
///
/// Symbols are created on first request and cached in the program's
/// KernelSymbols / GlobalSymbols tables; the whole lookup runs under the
/// program's SymbolListMutex. A kernel is constructed and initialised against
/// the program image, a global variable has its metadata read from the device.
///
/// \returns INVALID_ENUMERATION for an unknown \p Kind, the plugin's error if
/// creating the symbol fails, and success otherwise.
Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
                       ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
  auto &Device = Program->Image->getDevice();

  std::lock_guard<std::mutex> Lock(Program->SymbolListMutex);

  if (Kind == OL_SYMBOL_KIND_KERNEL) {
    auto &CachedKernel = Program->KernelSymbols[Name];
    if (!CachedKernel) {
      auto KernelOrErr = Device.constructKernel(Name);
      if (!KernelOrErr)
        return KernelOrErr.takeError();

      if (auto Err = KernelOrErr->init(Device, *Program->Image))
        return Err;

      CachedKernel = std::make_unique<ol_symbol_impl_t>(KernelOrErr->getName(),
                                                        &*KernelOrErr);
    }

    *Symbol = CachedKernel.get();
    return Error::success();
  }

  if (Kind == OL_SYMBOL_KIND_GLOBAL_VARIABLE) {
    auto &CachedGlobal = Program->GlobalSymbols[Name];
    if (!CachedGlobal) {
      GlobalTy GlobalObj{Name};
      auto &Handler = Device.Plugin.getGlobalHandler();
      if (auto Err = Handler.getGlobalMetadataFromDevice(
              Device, *Program->Image, GlobalObj))
        return Err;

      CachedGlobal = std::make_unique<ol_symbol_impl_t>(
          GlobalObj.getName().c_str(), std::move(GlobalObj));
    }

    *Symbol = CachedGlobal.get();
    return Error::success();
  }

  return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                            "getSymbol kind enum '%i' is invalid", Kind);
}
|
|
|
|
/// Shared implementation behind olGetSymbolInfo_impl and
/// olGetSymbolInfoSize_impl.
///
/// The property selected by \p PropName is written through an InfoWriter built
/// from \p PropSize, \p PropValue and \p PropSizeRet:
///  - OL_SYMBOL_INFO_KIND: the symbol's kind.
///  - OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS / _SIZE: pointer / size of the
///    global; these require a global-variable symbol and fail with SYMBOL_KIND
///    otherwise.
/// Any other \p PropName yields an INVALID_ENUMERATION error.
Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
                                ol_symbol_info_t PropName, size_t PropSize,
                                void *PropValue, size_t *PropSizeRet) {
  InfoWriter Info(PropSize, PropValue, PropSizeRet);

  // Yields success when the symbol has the kind the property needs, and a
  // descriptive SYMBOL_KIND error otherwise.
  auto RequireKind = [&](ol_symbol_kind_t Required) {
    if (Symbol->Kind == Required)
      return Plugin::success();

    std::string ErrBuffer;
    raw_string_ostream(ErrBuffer)
        << PropName << ": Expected a symbol of Kind " << Required
        << " but given a symbol of Kind " << Symbol->Kind;
    return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());
  };

  switch (PropName) {
  case OL_SYMBOL_INFO_KIND:
    return Info.write<ol_symbol_kind_t>(Symbol->Kind);
  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE: {
    if (auto Err = RequireKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
      return Err;

    auto &Global = std::get<GlobalTy>(Symbol->PluginImpl);
    if (PropName == OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS)
      return Info.write<void *>(Global.getPtr());
    return Info.write<size_t>(Global.getSize());
  }
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "olGetSymbolInfo enum '%i' is invalid", PropName);
  }
}
|
|
|
|
/// Value-returning flavour of the symbol info query: forwards \p PropSize and
/// \p PropValue to olGetSymbolInfoImplDetail() and supplies no size-return
/// pointer.
Error olGetSymbolInfo_impl(ol_symbol_handle_t Symbol, ol_symbol_info_t PropName,
                           size_t PropSize, void *PropValue) {
  size_t *const NoSizeRet = nullptr;
  return olGetSymbolInfoImplDetail(Symbol, PropName, PropSize, PropValue,
                                   NoSizeRet);
}
|
|
|
|
/// Size-returning flavour of the symbol info query: calls
/// olGetSymbolInfoImplDetail() with a zero size and no value buffer so that
/// only \p PropSizeRet is handed through.
Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
                               ol_symbol_info_t PropName, size_t *PropSizeRet) {
  const size_t NoValueSize = 0;
  void *const NoValue = nullptr;
  return olGetSymbolInfoImplDetail(Symbol, PropName, NoValueSize, NoValue,
                                   PropSizeRet);
}
|
|
|
|
/// Enqueue a host callback on \p Queue by forwarding \p Callback and
/// \p UserData, together with the queue's async info, to enqueueHostCall() on
/// the queue's device. The plugin's result is returned as-is.
Error olLaunchHostFunction_impl(ol_queue_handle_t Queue,
                                ol_host_function_cb_t Callback,
                                void *UserData) {
  auto *DeviceImpl = Queue->Device->Device;
  return DeviceImpl->enqueueHostCall(Callback, UserData, Queue->AsyncInfo);
}
|
|
|
|
/// Register the host range [\p Ptr, \p Ptr + \p Size) with \p Device via the
/// device's registerMemory().
///
/// Only the OL_MEMORY_REGISTER_FLAG_LOCK_MEMORY bit of \p Flags is forwarded.
/// On success the pointer returned by the plugin is stored in \p LockedPtr; on
/// failure the plugin's error is returned and \p LockedPtr is left untouched.
Error olMemRegister_impl(ol_device_handle_t Device, void *Ptr, size_t Size,
                         ol_memory_register_flags_t Flags, void **LockedPtr) {
  Expected<void *> LockedPtrOrErr = Device->Device->registerMemory(
      Ptr, Size, Flags & OL_MEMORY_REGISTER_FLAG_LOCK_MEMORY);
  if (auto Err = LockedPtrOrErr.takeError())
    return Err;

  *LockedPtr = *LockedPtrOrErr;
  return Error::success();
}
|
|
|
|
/// Undo a previous registration of \p Ptr on \p Device through the device's
/// unregisterMemory().
///
/// Only the OL_MEMORY_REGISTER_FLAG_UNLOCK_MEMORY bit of \p Flags is
/// forwarded; the plugin's result is returned as-is.
Error olMemUnregister_impl(ol_device_handle_t Device, void *Ptr,
                           ol_memory_register_flags_t Flags) {
  auto *DeviceImpl = Device->Device;
  return DeviceImpl->unregisterMemory(
      Ptr, Flags & OL_MEMORY_REGISTER_FLAG_UNLOCK_MEMORY);
}
|
|
|
|
/// Report through \p IsQueueWorkCompleted whether the work on \p Queue has
/// finished.
///
/// When the queue's async info has an underlying plugin queue, the question is
/// delegated to the device's queryAsync(). Without one there is nothing to
/// wait for, so the flag (if a pointer was supplied) is set to true.
Error olQueryQueue_impl(ol_queue_handle_t Queue, bool *IsQueueWorkCompleted) {
  const bool HasPluginQueue = static_cast<bool>(Queue->AsyncInfo->Queue);
  if (!HasPluginQueue) {
    // No underlying queue means there's no work to complete.
    if (IsQueueWorkCompleted)
      *IsQueueWorkCompleted = true;
    return Error::success();
  }

  return Queue->Device->Device->queryAsync(Queue->AsyncInfo, false,
                                           IsQueueWorkCompleted);
}
|
|
|
|
} // namespace offload
|
|
} // namespace llvm
|