Summary: The Offloading library wraps the underlying plugins. The problem is that we currently initialize every plugin we find, even if the program does not need it. This is very expensive for trivial uses, as fully heterogeneous usage is quite rare; in practice, you always pay a 200 ms penalty just for having CUDA installed. This patch changes the behavior to provide accessors into the plugins and devices that allow them to be initialized lazily. We use a once_flag, which should take a cheap fast-path check once initialization has finished while still blocking concurrent callers while it is in progress. Making full use of this will require a way to filter platforms more specifically. I'm still thinking about what this would look like as an API: either we add an extra iterate function that takes a callback on the platform, or we just provide a helper to find all the devices that can run a given image. Maybe both? Fixes: https://github.com/llvm/llvm-project/issues/159636
1237 lines
44 KiB
C++
1237 lines
44 KiB
C++
//===- ol_impl.cpp - Implementation of the new LLVM/Offload API ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This contains the definitions of the new LLVM/Offload API entry points. See
|
|
// new-api/API/README.md for more information.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "OffloadImpl.hpp"
|
|
#include "Helpers.hpp"
|
|
#include "OffloadPrint.hpp"
|
|
#include "PluginManager.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include <OffloadAPI.h>
|
|
|
|
#include <mutex>
|
|
|
|
// TODO: Some plugins expect to be linked into libomptarget which defines these
|
|
// symbols to implement ompt callbacks. The least invasive workaround here is to
|
|
// define them in libLLVMOffload as false/null so they are never used. In future
|
|
// it would be better to allow the plugins to implement callbacks without
|
|
// pulling in details from libomptarget.
|
|
#ifdef OMPT_SUPPORT
// Stub OMPT state referenced by the plugins. It is defined here as false/null
// so the plugins' OMPT hooks stay inactive when linked into liboffload.
namespace llvm::omp::target::ompt {
bool Initialized = false;
ompt_get_callback_t lookupCallbackByCode = nullptr;
ompt_function_lookup_t lookupCallbackByName = nullptr;
} // namespace llvm::omp::target::ompt
#endif
|
|
|
|
using namespace llvm::omp::target;
|
|
using namespace llvm::omp::target::plugin;
|
|
using namespace error;
|
|
|
|
struct ol_platform_impl_t {
|
|
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
|
|
ol_platform_backend_t BackendType)
|
|
: BackendType(BackendType), Plugin(std::move(Plugin)) {}
|
|
ol_platform_backend_t BackendType;
|
|
|
|
/// Get the plugin, lazily initializing it if necessary.
|
|
llvm::Expected<GenericPluginTy *> getPlugin() {
|
|
if (llvm::Error Err = init())
|
|
return Err;
|
|
return Plugin.get();
|
|
}
|
|
|
|
/// Get the device list, lazily initializing it if necessary.
|
|
llvm::Expected<llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> &>
|
|
getDevices() {
|
|
if (llvm::Error Err = init())
|
|
return Err;
|
|
return Devices;
|
|
}
|
|
|
|
/// Complete all pending work for this platform and perform any needed
|
|
/// cleanup.
|
|
///
|
|
/// After calling this function, no liboffload functions should be called with
|
|
/// this platform handle.
|
|
llvm::Error destroy();
|
|
|
|
/// Initialize the associated plugin and devices.
|
|
llvm::Error init();
|
|
|
|
/// Direct access to the plugin, may be uninitialized if accessed here.
|
|
std::unique_ptr<GenericPluginTy> Plugin;
|
|
|
|
private:
|
|
std::once_flag Initialized;
|
|
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
|
|
};
|
|
|
|
// Handle type definitions. Ideally these would be 1:1 with the plugins, but
|
|
// we add some additional data here for now to avoid churn in the plugin
|
|
// interface.
|
|
struct ol_device_impl_t {
|
|
ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
|
|
ol_platform_impl_t &Platform, InfoTreeNode &&DevInfo)
|
|
: DeviceNum(DeviceNum), Device(Device), Platform(Platform),
|
|
Info(std::forward<InfoTreeNode>(DevInfo)) {}
|
|
|
|
~ol_device_impl_t() {
|
|
assert(!OutstandingQueues.size() &&
|
|
"Device object dropped with outstanding queues");
|
|
}
|
|
|
|
int DeviceNum;
|
|
GenericDeviceTy *Device;
|
|
ol_platform_impl_t &Platform;
|
|
InfoTreeNode Info;
|
|
|
|
llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
|
|
std::mutex OutstandingQueuesMutex;
|
|
|
|
/// If the device has any outstanding queues that are now complete, remove it
|
|
/// from the list and return it.
|
|
///
|
|
/// Queues may be added to the outstanding queue list by olDestroyQueue if
|
|
/// they are destroyed but not completed.
|
|
__tgt_async_info *getOutstandingQueue() {
|
|
// Not locking the `size()` access is fine here - In the worst case we
|
|
// either miss a queue that exists or loop through an empty array after
|
|
// taking the lock. Both are sub-optimal but not that bad.
|
|
if (OutstandingQueues.size()) {
|
|
std::lock_guard<std::mutex> Lock(OutstandingQueuesMutex);
|
|
|
|
// As queues are pulled and popped from this list, longer running queues
|
|
// naturally bubble to the start of the array. Hence looping backwards.
|
|
for (auto Q = OutstandingQueues.rbegin(); Q != OutstandingQueues.rend();
|
|
Q++) {
|
|
if (!Device->hasPendingWork(*Q)) {
|
|
auto OutstandingQueue = *Q;
|
|
*Q = OutstandingQueues.back();
|
|
OutstandingQueues.pop_back();
|
|
return OutstandingQueue;
|
|
}
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
/// Complete all pending work for this device and perform any needed cleanup.
|
|
///
|
|
/// After calling this function, no liboffload functions should be called with
|
|
/// this device handle.
|
|
llvm::Error destroy() {
|
|
llvm::Error Result = Plugin::success();
|
|
for (auto Q : OutstandingQueues)
|
|
if (auto Err = Device->synchronize(Q, /*Release=*/true))
|
|
Result = llvm::joinErrors(std::move(Result), std::move(Err));
|
|
OutstandingQueues.clear();
|
|
return Result;
|
|
}
|
|
};
|
|
|
|
llvm::Error ol_platform_impl_t::destroy() {
|
|
llvm::Error Result = Plugin::success();
|
|
for (auto &D : Devices)
|
|
if (auto Err = D->destroy())
|
|
Result = llvm::joinErrors(std::move(Result), std::move(Err));
|
|
|
|
if (auto Res = Plugin->deinit())
|
|
Result = llvm::joinErrors(std::move(Result), std::move(Res));
|
|
|
|
return Result;
|
|
}
|
|
|
|
llvm::Error ol_platform_impl_t::init() {
|
|
std::unique_ptr<llvm::Error> Storage;
|
|
|
|
// This can be called concurrently, make sure we only do the actual
|
|
// initialization once.
|
|
std::call_once(Initialized, [&]() {
|
|
// FIXME: Need better handling for the host platform.
|
|
if (!Plugin)
|
|
return;
|
|
|
|
llvm::Error Err = Plugin->init();
|
|
if (Err) {
|
|
Storage = std::make_unique<llvm::Error>(std::move(Err));
|
|
return;
|
|
}
|
|
|
|
for (auto Id = 0, End = Plugin->getNumDevices(); Id != End; Id++) {
|
|
if (llvm::Error Err = Plugin->initDevice(Id)) {
|
|
Storage = std::make_unique<llvm::Error>(std::move(Err));
|
|
return;
|
|
}
|
|
|
|
auto Device = &Plugin->getDevice(Id);
|
|
auto Info = Device->obtainInfoImpl();
|
|
if (llvm::Error Err = Info.takeError()) {
|
|
Storage = std::make_unique<llvm::Error>(std::move(Err));
|
|
return;
|
|
}
|
|
Devices.emplace_back(std::make_unique<ol_device_impl_t>(
|
|
Id, Device, *this, std::move(*Info)));
|
|
}
|
|
});
|
|
|
|
return Storage ? std::move(*Storage) : llvm::Error::success();
|
|
}
|
|
|
|
struct ol_queue_impl_t {
|
|
ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
|
|
: AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
|
|
__tgt_async_info *AsyncInfo;
|
|
ol_device_handle_t Device;
|
|
// A unique identifier for the queue
|
|
size_t Id;
|
|
static std::atomic<size_t> IdCounter;
|
|
};
|
|
std::atomic<size_t> ol_queue_impl_t::IdCounter(0);
|
|
|
|
struct ol_event_impl_t {
|
|
ol_event_impl_t(void *EventInfo, ol_device_handle_t Device,
|
|
ol_queue_handle_t Queue)
|
|
: EventInfo(EventInfo), Device(Device), QueueId(Queue->Id), Queue(Queue) {
|
|
}
|
|
// EventInfo may be null, in which case the event should be considered always
|
|
// complete
|
|
void *EventInfo;
|
|
ol_device_handle_t Device;
|
|
size_t QueueId;
|
|
// Events may outlive the queue - don't assume this is always valid.
|
|
// It is provided only to implement OL_EVENT_INFO_QUEUE. Use QueueId to check
|
|
// for queue equality instead.
|
|
ol_queue_handle_t Queue;
|
|
};
|
|
|
|
struct ol_program_impl_t {
|
|
ol_program_impl_t(plugin::DeviceImageTy *Image,
|
|
llvm::MemoryBufferRef DeviceImage)
|
|
: Image(Image), DeviceImage(DeviceImage) {}
|
|
plugin::DeviceImageTy *Image;
|
|
std::mutex SymbolListMutex;
|
|
llvm::MemoryBufferRef DeviceImage;
|
|
llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> KernelSymbols;
|
|
llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> GlobalSymbols;
|
|
};
|
|
|
|
struct ol_symbol_impl_t {
|
|
ol_symbol_impl_t(const char *Name, GenericKernelTy *Kernel)
|
|
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL), Name(Name) {}
|
|
ol_symbol_impl_t(const char *Name, GlobalTy &&Global)
|
|
: PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {}
|
|
std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
|
|
ol_symbol_kind_t Kind;
|
|
llvm::StringRef Name;
|
|
};
|
|
|
|
namespace llvm {
|
|
namespace offload {
|
|
|
|
struct AllocInfo {
|
|
ol_device_handle_t Device;
|
|
ol_alloc_type_t Type;
|
|
void *Start;
|
|
// One byte past the end
|
|
void *End;
|
|
};
|
|
|
|
// Global shared state for liboffload
|
|
struct OffloadContext;
|
|
// This pointer is non-null if and only if the context is valid and fully
|
|
// initialized
|
|
static std::atomic<OffloadContext *> OffloadContextVal;
|
|
std::mutex OffloadContextValMutex;
|
|
struct OffloadContext {
|
|
OffloadContext(OffloadContext &) = delete;
|
|
OffloadContext(OffloadContext &&) = delete;
|
|
OffloadContext &operator=(OffloadContext &) = delete;
|
|
OffloadContext &operator=(OffloadContext &&) = delete;
|
|
|
|
bool TracingEnabled = false;
|
|
bool ValidationEnabled = true;
|
|
DenseMap<void *, AllocInfo> AllocInfoMap{};
|
|
std::mutex AllocInfoMapMutex{};
|
|
// Partitioned list of memory base addresses. Each element in this list is a
|
|
// key in AllocInfoMap
|
|
llvm::SmallVector<void *> AllocBases{};
|
|
SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
|
|
ol_device_handle_t HostDevice;
|
|
size_t RefCount;
|
|
|
|
static OffloadContext &get() {
|
|
assert(OffloadContextVal);
|
|
return *OffloadContextVal;
|
|
}
|
|
};
|
|
|
|
// If the context is uninited, then we assume tracing is disabled
|
|
bool isTracingEnabled() {
|
|
return isOffloadInitialized() && OffloadContext::get().TracingEnabled;
|
|
}
|
|
bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; }
|
|
bool isOffloadInitialized() { return OffloadContextVal != nullptr; }
|
|
|
|
/// Release a heap-allocated handle. Releasing cannot fail, so this always
/// reports success.
template <typename HandleT> Error olDestroy(HandleT Handle) {
  // Adopt the raw handle; it is deleted when `Owned` goes out of scope.
  std::unique_ptr<std::remove_pointer_t<HandleT>> Owned{Handle};
  return Error::success();
}
|
|
|
|
/// Map a plugin name from Targets.def to the public backend enumerator.
/// Names without a dedicated backend map to OL_PLATFORM_BACKEND_UNKNOWN.
constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
  if (Name == "amdgpu")
    return OL_PLATFORM_BACKEND_AMDGPU;
  if (Name == "cuda")
    return OL_PLATFORM_BACKEND_CUDA;
  return OL_PLATFORM_BACKEND_UNKNOWN;
}
|
|
|
|
// Every plugin exports this method to create an instance of the plugin type.
// Expanding Targets.def with this macro declares one `createPlugin_<name>`
// factory per configured plugin.
#define PLUGIN_TARGET(Name)                                                    \
  extern "C" GenericPluginTy *createPlugin_##Name();
#include "Shared/Targets.def"
|
|
|
|
/// Populate `Context` with one (not yet initialized) platform per plugin, plus
/// the special host platform and its single virtual device, then read the
/// tracing/validation environment switches.
Error initPlugins(OffloadContext &Context) {
  // Attempt to create an instance of each supported plugin. The "host" plugin
  // is skipped; a dedicated host platform is created below instead.
#define PLUGIN_TARGET(Name)                                                    \
  do {                                                                         \
    if (StringRef(#Name) != "host") {                                          \
      std::unique_ptr<GenericPluginTy> NewPlugin(createPlugin_##Name());       \
      Context.Platforms.emplace_back(std::make_unique<ol_platform_impl_t>(     \
          std::move(NewPlugin), pluginNameToBackend(#Name)));                  \
    }                                                                          \
  } while (false);
#include "Shared/Targets.def"

  // The host platform has no plugin; its virtual device is added by hand.
  auto &HostPlatform = Context.Platforms.emplace_back(
      std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST));
  auto HostDevices = HostPlatform->getDevices();
  if (!HostDevices)
    return HostDevices.takeError();
  auto &HostDeviceSlot =
      HostDevices->emplace_back(std::make_unique<ol_device_impl_t>(
          -1, nullptr, *HostPlatform, InfoTreeNode{}));
  Context.HostDevice = HostDeviceSlot.get();

  Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
  Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");

  return Plugin::success();
}
|
|
|
|
/// Create the global context on first use, or bump its reference count if it
/// already exists.
Error olInit_impl() {
  std::lock_guard<std::mutex> Guard(OffloadContextValMutex);

  if (!isOffloadInitialized()) {
    // Build into a temporary so entry points reading OffloadContextVal never
    // observe a partially initialized context.
    auto *Fresh = new OffloadContext{};
    Error Result = initPlugins(*Fresh);
    OffloadContextVal.store(Fresh);
    ++Fresh->RefCount;
    return Result;
  }

  ++OffloadContext::get().RefCount;
  return Plugin::success();
}
|
|
|
|
/// Drop one reference to the global context; the last reference tears down
/// every platform whose plugin was initialized and frees the context.
Error olShutDown_impl() {
  std::lock_guard<std::mutex> Guard(OffloadContextValMutex);

  if (--OffloadContext::get().RefCount > 0)
    return Error::success();

  // Unpublish the context before tearing it down; it is freed when `Dying`
  // leaves scope.
  std::unique_ptr<OffloadContext> Dying(OffloadContextVal.exchange(nullptr));

  llvm::Error Result = Error::success();
  for (auto &Platform : Dying->Platforms) {
    // The host platform has no plugin, and lazily created platforms that were
    // never initialized have nothing to tear down.
    GenericPluginTy *PluginPtr = Platform->Plugin.get();
    if (PluginPtr && PluginPtr->is_initialized()) {
      if (auto Err = Platform->destroy())
        Result = llvm::joinErrors(std::move(Result), std::move(Err));
    }
  }
  return Result;
}
|
|
|
|
/// Shared implementation of olGetPlatformInfo and olGetPlatformInfoSize.
///
/// The plugin is deliberately not initialized here: none of the properties
/// below need it. Initialization will be required once info is added that
/// depends on it.
Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
                                  ol_platform_info_t PropName, size_t PropSize,
                                  void *PropValue, size_t *PropSizeRet) {
  InfoWriter Info(PropSize, PropValue, PropSizeRet);
  const bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;

  switch (PropName) {
  case OL_PLATFORM_INFO_BACKEND:
    return Info.write<ol_platform_backend_t>(Platform->BackendType);
  case OL_PLATFORM_INFO_NAME:
    return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName());
  case OL_PLATFORM_INFO_VENDOR_NAME:
    // TODO: Implement this
    return Info.writeString("Unknown platform vendor");
  case OL_PLATFORM_INFO_VERSION: {
    const std::string Version = formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
                                        OL_VERSION_MINOR, OL_VERSION_PATCH)
                                    .str();
    return Info.writeString(Version);
  }
  default:
    break;
  }

  return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                            "getPlatformInfo enum '%i' is invalid", PropName);
}
|
|
|
|
Error olGetPlatformInfo_impl(ol_platform_handle_t Platform,
|
|
ol_platform_info_t PropName, size_t PropSize,
|
|
void *PropValue) {
|
|
return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue,
|
|
nullptr);
|
|
}
|
|
|
|
Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
|
|
ol_platform_info_t PropName,
|
|
size_t *PropSizeRet) {
|
|
return olGetPlatformInfoImplDetail(Platform, PropName, 0, nullptr,
|
|
PropSizeRet);
|
|
}
|
|
|
|
/// Shared implementation of olGetDeviceInfo and olGetDeviceInfoSize for
/// plugin-backed (non-host) devices.
///
/// A few properties are answered directly here. Everything else is read from
/// the InfoTreeNode the plugin produced when the device was initialized, and
/// is checked for the expected type and range before being written out.
Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
                                ol_device_info_t PropName, size_t PropSize,
                                void *PropValue, size_t *PropSizeRet) {
  assert(Device != OffloadContext::get().HostDevice);
  InfoWriter Info(PropSize, PropValue, PropSizeRet);

  // Build an error with the requested code, prefixed by the property being
  // queried. The text is passed through "%s" so it is never treated as a
  // format string.
  auto makeError = [&](ErrorCode Code, StringRef Err) {
    std::string ErrBuffer;
    llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
    return Plugin::error(Code, "%s", ErrBuffer.c_str());
  };

  // These are not implemented by the plugin interface
  switch (PropName) {
  case OL_DEVICE_INFO_PLATFORM:
    return Info.write<void *>(&Device->Platform);

  case OL_DEVICE_INFO_TYPE:
    return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);

  case OL_DEVICE_INFO_SINGLE_FP_CONFIG:
  case OL_DEVICE_INFO_DOUBLE_FP_CONFIG: {
    ol_device_fp_capability_flags_t flags{0};
    flags |= OL_DEVICE_FP_CAPABILITY_FLAG_CORRECTLY_ROUNDED_DIVIDE_SQRT |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_NEAREST |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_ZERO |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_INF |
             OL_DEVICE_FP_CAPABILITY_FLAG_INF_NAN |
             OL_DEVICE_FP_CAPABILITY_FLAG_DENORM |
             OL_DEVICE_FP_CAPABILITY_FLAG_FMA;
    return Info.write(flags);
  }

  case OL_DEVICE_INFO_HALF_FP_CONFIG:
    return Info.write<ol_device_fp_capability_flags_t>(0);

  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_CHAR:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_SHORT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_INT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_FLOAT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_DOUBLE:
    return Info.write<uint32_t>(1);

  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_HALF:
    return Info.write<uint32_t>(0);

  // None of the existing plugins specify a limit on a single allocation,
  // so return the global memory size instead
  case OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
    [[fallthrough]];
  // AMD doesn't provide the global memory size (trivially) with the device info
  // struct, so use the plugin interface
  case OL_DEVICE_INFO_GLOBAL_MEM_SIZE: {
    uint64_t Mem;
    if (auto Err = Device->Device->getDeviceMemorySize(Mem))
      return Err;
    return Info.write<uint64_t>(Mem);
  }

  default:
    break;
  }

  if (PropName >= OL_DEVICE_INFO_LAST)
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "getDeviceInfo enum '%i' is invalid", PropName);

  auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName));
  if (!EntryOpt)
    return makeError(ErrorCode::UNIMPLEMENTED,
                     "plugin did not provide a response for this information");
  auto Entry = *EntryOpt;

  // Retrieve properties from the plugin interface
  switch (PropName) {
  case OL_DEVICE_INFO_NAME:
  case OL_DEVICE_INFO_PRODUCT_NAME:
  case OL_DEVICE_INFO_VENDOR:
  case OL_DEVICE_INFO_DRIVER_VERSION: {
    // String values
    if (!std::holds_alternative<std::string>(Entry->Value))
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned incorrect type");
    return Info.writeString(std::get<std::string>(Entry->Value).c_str());
  }

  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
  case OL_DEVICE_INFO_MAX_WORK_SIZE:
  case OL_DEVICE_INFO_VENDOR_ID:
  case OL_DEVICE_INFO_NUM_COMPUTE_UNITS:
  case OL_DEVICE_INFO_ADDRESS_BITS:
  case OL_DEVICE_INFO_MAX_CLOCK_FREQUENCY:
  case OL_DEVICE_INFO_MEMORY_CLOCK_RATE: {
    // Uint32 values
    if (!std::holds_alternative<uint64_t>(Entry->Value))
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned incorrect type");
    auto Value = std::get<uint64_t>(Entry->Value);
    if (Value > std::numeric_limits<uint32_t>::max())
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned out of range device info");
    return Info.write(static_cast<uint32_t>(Value));
  }

  case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION:
  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION: {
    // {x, y, z} triples
    ol_dimensions_t Out{0, 0, 0};

    // Read one named dimension, applying the same type and range checks as
    // the scalar uint32 properties above instead of silently truncating.
    auto getField = [&](StringRef Name, uint32_t &Dest) -> Error {
      auto F = Entry->get(Name);
      if (!F)
        return makeError(ErrorCode::BACKEND_FAILURE,
                         "plugin didn't provide all values for dimensions");
      if (!std::holds_alternative<size_t>((*F)->Value))
        return makeError(
            ErrorCode::BACKEND_FAILURE,
            "plugin returned incorrect type for dimensions element");
      size_t Value = std::get<size_t>((*F)->Value);
      if (Value > std::numeric_limits<uint32_t>::max())
        return makeError(ErrorCode::BACKEND_FAILURE,
                         "plugin returned out of range device info");
      Dest = static_cast<uint32_t>(Value);
      return Plugin::success();
    };

    if (auto Res = getField("x", Out.x))
      return Res;
    if (auto Res = getField("y", Out.y))
      return Res;
    if (auto Res = getField("z", Out.z))
      return Res;

    return Info.write(Out);
  }

  default:
    llvm_unreachable("Unimplemented device info");
  }
}
|
|
|
|
/// Shared implementation of olGetDeviceInfo and olGetDeviceInfoSize for the
/// virtual host device. Every property is a fixed value; no plugin is queried.
Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
                                    ol_device_info_t PropName, size_t PropSize,
                                    void *PropValue, size_t *PropSizeRet) {
  assert(Device == OffloadContext::get().HostDevice);
  InfoWriter Info(PropSize, PropValue, PropSizeRet);

  constexpr auto MaxU32 = std::numeric_limits<uint32_t>::max();
  const ol_device_fp_capability_flags_t HostFpCaps =
      OL_DEVICE_FP_CAPABILITY_FLAG_CORRECTLY_ROUNDED_DIVIDE_SQRT |
      OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_NEAREST |
      OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_ZERO |
      OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_INF |
      OL_DEVICE_FP_CAPABILITY_FLAG_INF_NAN |
      OL_DEVICE_FP_CAPABILITY_FLAG_DENORM | OL_DEVICE_FP_CAPABILITY_FLAG_FMA;

  switch (PropName) {
  case OL_DEVICE_INFO_PLATFORM:
    return Info.write<void *>(&Device->Platform);
  case OL_DEVICE_INFO_TYPE:
    return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);

  // Identification strings.
  case OL_DEVICE_INFO_NAME:
  case OL_DEVICE_INFO_PRODUCT_NAME:
    return Info.writeString("Virtual Host Device");
  case OL_DEVICE_INFO_VENDOR:
    return Info.writeString("Liboffload");
  case OL_DEVICE_INFO_DRIVER_VERSION:
    return Info.writeString(LLVM_VERSION_STRING);

  // Work sizes: groups are 1x1x1, total work is bounded only by uint32_t.
  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION:
    return Info.write<ol_dimensions_t>(ol_dimensions_t{1, 1, 1});
  case OL_DEVICE_INFO_MAX_WORK_SIZE:
    return Info.write<uint32_t>(MaxU32);
  case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION:
    return Info.write<ol_dimensions_t>(
        ol_dimensions_t{MaxU32, MaxU32, MaxU32});

  // Floating-point capabilities.
  case OL_DEVICE_INFO_SINGLE_FP_CONFIG:
  case OL_DEVICE_INFO_DOUBLE_FP_CONFIG:
    return Info.write<ol_device_fp_capability_flags_t>(HostFpCaps);
  case OL_DEVICE_INFO_HALF_FP_CONFIG:
    return Info.write<ol_device_fp_capability_flags_t>(0);

  // uint32 properties reported as 1.
  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
  case OL_DEVICE_INFO_NUM_COMPUTE_UNITS:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_CHAR:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_SHORT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_INT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_FLOAT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_DOUBLE:
    return Info.write<uint32_t>(1);

  // uint32 properties reported as 0.
  case OL_DEVICE_INFO_VENDOR_ID:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_HALF:
    return Info.write<uint32_t>(0);

  // NOTE(review): the two clock properties share the address-bits value (the
  // bit width of uintptr_t); this looks like a placeholder - confirm intent.
  case OL_DEVICE_INFO_MAX_CLOCK_FREQUENCY:
  case OL_DEVICE_INFO_MEMORY_CLOCK_RATE:
  case OL_DEVICE_INFO_ADDRESS_BITS:
    return Info.write<uint32_t>(std::numeric_limits<uintptr_t>::digits);

  case OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
  case OL_DEVICE_INFO_GLOBAL_MEM_SIZE:
    return Info.write<uint64_t>(0);

  default:
    break;
  }

  return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                            "getDeviceInfo enum '%i' is invalid", PropName);
}
|
|
|
|
Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
|
|
size_t PropSize, void *PropValue) {
|
|
if (Device == OffloadContext::get().HostDevice)
|
|
return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue,
|
|
nullptr);
|
|
return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
|
|
nullptr);
|
|
}
|
|
|
|
Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
|
|
ol_device_info_t PropName, size_t *PropSizeRet) {
|
|
if (Device == OffloadContext::get().HostDevice)
|
|
return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr,
|
|
PropSizeRet);
|
|
return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
|
|
}
|
|
|
|
/// Invoke `Callback` on every device of every platform until it returns false.
/// Visiting a platform's devices triggers its lazy initialization; a platform
/// that fails to initialize aborts the iteration with that error.
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
  for (auto &Platform : OffloadContext::get().Platforms) {
    auto Devices = Platform->getDevices();
    if (auto Err = Devices.takeError())
      return Err;

    for (auto &Dev : *Devices)
      if (!Callback(Dev.get(), UserData))
        return Error::success(); // The callback asked to stop early.
  }
  return Error::success();
}
|
|
|
|
/// Translate a public allocation kind to the plugin's allocation kind.
/// Managed memory, and any unrecognized value, maps to shared memory.
TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {
  if (Type == OL_ALLOC_TYPE_DEVICE)
    return TARGET_ALLOC_DEVICE;
  if (Type == OL_ALLOC_TYPE_HOST)
    return TARGET_ALLOC_HOST;
  return TARGET_ALLOC_SHARED;
}
|
|
|
|
constexpr size_t MAX_ALLOC_TRIES = 50;
|
|
Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
|
|
size_t Size, void **AllocationOut) {
|
|
SmallVector<void *> Rejects;
|
|
|
|
// Repeat the allocation up to a certain amount of times. If it happens to
|
|
// already be allocated (e.g. by a device from another vendor) throw it away
|
|
// and try again.
|
|
for (size_t Count = 0; Count < MAX_ALLOC_TRIES; Count++) {
|
|
auto NewAlloc = Device->Device->dataAlloc(Size, nullptr,
|
|
convertOlToPluginAllocTy(Type));
|
|
if (!NewAlloc)
|
|
return NewAlloc.takeError();
|
|
|
|
void *NewEnd = &static_cast<char *>(*NewAlloc)[Size];
|
|
auto &AllocBases = OffloadContext::get().AllocBases;
|
|
auto &AllocInfoMap = OffloadContext::get().AllocInfoMap;
|
|
{
|
|
std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
|
|
|
|
// Check that this memory region doesn't overlap another one
|
|
// That is, the start of this allocation needs to be after another
|
|
// allocation's end point, and the end of this allocation needs to be
|
|
// before the next one's start.
|
|
// `Gap` is the first alloc who ends after the new alloc's start point.
|
|
auto Gap =
|
|
std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc,
|
|
[&](const void *Iter, const void *Val) {
|
|
return AllocInfoMap.at(Iter).End <= Val;
|
|
});
|
|
if (Gap == AllocBases.end() || NewEnd <= AllocInfoMap.at(*Gap).Start) {
|
|
// Success, no conflict
|
|
AllocInfoMap.insert_or_assign(
|
|
*NewAlloc, AllocInfo{Device, Type, *NewAlloc, NewEnd});
|
|
AllocBases.insert(
|
|
std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc),
|
|
*NewAlloc);
|
|
*AllocationOut = *NewAlloc;
|
|
|
|
for (void *R : Rejects)
|
|
if (auto Err =
|
|
Device->Device->dataDelete(R, convertOlToPluginAllocTy(Type)))
|
|
return Err;
|
|
return Error::success();
|
|
}
|
|
|
|
// To avoid the next attempt allocating the same memory we just freed, we
|
|
// hold onto it until we complete the allocation
|
|
Rejects.push_back(*NewAlloc);
|
|
}
|
|
}
|
|
|
|
// We've tried multiple times, and can't allocate a non-overlapping region.
|
|
return createOffloadError(ErrorCode::BACKEND_FAILURE,
|
|
"failed to allocate non-overlapping memory");
|
|
}
|
|
|
|
/// Forget the allocation starting at `Address` and return its memory to the
/// owning device. `Address` must be a base pointer returned by olMemAlloc.
Error olMemFree_impl(void *Address) {
  ol_device_handle_t Owner;
  ol_alloc_type_t Kind;
  {
    auto &Context = OffloadContext::get();
    std::lock_guard<std::mutex> Lock(Context.AllocInfoMapMutex);

    auto It = Context.AllocInfoMap.find(Address);
    if (It == Context.AllocInfoMap.end())
      return createOffloadError(ErrorCode::INVALID_ARGUMENT,
                                "address is not a known allocation");
    Owner = It->second.Device;
    Kind = It->second.Type;
    Context.AllocInfoMap.erase(It);

    auto &Bases = Context.AllocBases;
    Bases.erase(std::lower_bound(Bases.begin(), Bases.end(), Address));
  }

  // The device-side free happens after the bookkeeping lock is released.
  return Owner->Device->dataDelete(Address, convertOlToPluginAllocTy(Kind));
}
|
|
|
|
/// Shared implementation of olGetMemInfo and olGetMemInfoSize. `Ptr` may be
/// the base of an allocation or any address inside one.
Error olGetMemInfoImplDetail(const void *Ptr, ol_mem_info_t PropName,
                             size_t PropSize, void *PropValue,
                             size_t *PropSizeRet) {
  InfoWriter Info(PropSize, PropValue, PropSizeRet);
  auto &Context = OffloadContext::get();
  std::lock_guard<std::mutex> Lock(Context.AllocInfoMapMutex);

  auto &AllocInfoMap = Context.AllocInfoMap;
  const AllocInfo *Alloc = nullptr;
  if (auto Exact = AllocInfoMap.find(Ptr); Exact != AllocInfoMap.end()) {
    // Fast path: Ptr is exactly a base address.
    Alloc = &Exact->second;
  } else {
    // Slow path: find the first allocation whose end lies after Ptr, then
    // make sure Ptr is not in the gap before that allocation starts.
    auto &AllocBases = Context.AllocBases;
    auto Loc = std::lower_bound(AllocBases.begin(), AllocBases.end(), Ptr,
                                [&](const void *Iter, const void *Val) {
                                  return AllocInfoMap.at(Iter).End <= Val;
                                });
    if (Loc == AllocBases.end() || Ptr < AllocInfoMap.at(*Loc).Start)
      return Plugin::error(ErrorCode::NOT_FOUND,
                           "allocated memory information not found");
    Alloc = &AllocInfoMap.at(*Loc);
  }

  switch (PropName) {
  case OL_MEM_INFO_DEVICE:
    return Info.write<ol_device_handle_t>(Alloc->Device);
  case OL_MEM_INFO_BASE:
    return Info.write<void *>(Alloc->Start);
  case OL_MEM_INFO_SIZE: {
    auto *Begin = static_cast<char *>(Alloc->Start);
    auto *End = static_cast<char *>(Alloc->End);
    return Info.write<size_t>(End - Begin);
  }
  case OL_MEM_INFO_TYPE:
    return Info.write<ol_alloc_type_t>(Alloc->Type);
  default:
    break;
  }

  return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                            "olGetMemInfo enum '%i' is invalid", PropName);
}
|
|
|
|
Error olGetMemInfo_impl(const void *Ptr, ol_mem_info_t PropName,
|
|
size_t PropSize, void *PropValue) {
|
|
return olGetMemInfoImplDetail(Ptr, PropName, PropSize, PropValue, nullptr);
|
|
}
|
|
|
|
Error olGetMemInfoSize_impl(const void *Ptr, ol_mem_info_t PropName,
|
|
size_t *PropSizeRet) {
|
|
return olGetMemInfoImplDetail(Ptr, PropName, 0, nullptr, PropSizeRet);
|
|
}
|
|
|
|
/// Create a queue on `Device`, recycling a drained queue left behind by
/// olDestroyQueue when one is available.
Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
  auto NewQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);
  auto *Backend = Device->Device;

  __tgt_async_info *Recycled = Device->getOutstandingQueue();
  if (!Recycled) {
    // Nothing to recycle: ask the plugin for a fresh queue.
    if (auto Err = Backend->initAsyncInfo(&NewQueue->AsyncInfo))
      return Err;
  } else {
    // The queue is empty, but we still need to sync it to release any
    // temporary memory allocations or do other cleanup.
    if (auto Err = Backend->synchronize(Recycled, /*Release=*/false))
      return Err;
    NewQueue->AsyncInfo = Recycled;
  }

  *Queue = NewQueue.release();
  return Error::success();
}
|
|
|
|
/// Destroy a queue handle. A queue that still has work is parked on its
/// device's outstanding list (where olCreateQueue can later recycle it);
/// an idle queue is synchronized and released immediately.
Error olDestroyQueue_impl(ol_queue_handle_t Queue) {
  auto *Owner = Queue->Device;

  // This is safe; as soon as olDestroyQueue is called it is not possible to add
  // any more work to the queue, so if it's finished now it will remain finished
  // forever.
  auto Busy = Owner->Device->hasPendingWork(Queue->AsyncInfo);
  if (!Busy)
    return Busy.takeError();

  if (*Busy) {
    std::lock_guard<std::mutex> Lock(Owner->OutstandingQueuesMutex);
    Owner->OutstandingQueues.push_back(Queue->AsyncInfo);
  } else if (auto Err = Owner->Device->synchronize(Queue->AsyncInfo,
                                                   /*Release=*/true)) {
    return Err;
  }

  return olDestroy(Queue);
}
|
|
|
|
/// Block until all work submitted to `Queue` has completed.
Error olSyncQueue_impl(ol_queue_handle_t Queue) {
  // Host plugin doesn't have a queue set so it's not safe to call synchronize
  // on it, but we have nothing to synchronize in that situation anyway.
  if (!Queue->AsyncInfo->Queue)
    return Error::success();

  // Release=false keeps the underlying queue object alive so other offload
  // threads can keep submitting work to it concurrently.
  return Queue->Device->Device->synchronize(Queue->AsyncInfo,
                                            /*Release=*/false);
}
|
|
|
|
/// Make \p Queue wait on each of the \p NumEvents events in \p Events, using
/// the plugin's waitEvent().
///
/// Events recorded on \p Queue itself, and events without plugin event info
/// (complete on creation), are skipped. A null entry is an
/// INVALID_NULL_HANDLE error.
Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events,
                        size_t NumEvents) {
  auto *Dev = Queue->Device->Device;

  for (size_t Idx = 0; Idx != NumEvents; ++Idx) {
    ol_event_handle_t Ev = Events[Idx];
    if (!Ev)
      return Plugin::error(ErrorCode::INVALID_NULL_HANDLE,
                           "olWaitEvents asked to wait on a NULL event");

    // Nothing to wait for if the event belongs to this very queue or was
    // already complete when it was created.
    const bool NeedsWait = Ev->QueueId != Queue->Id && Ev->EventInfo;
    if (!NeedsWait)
      continue;

    if (auto Err = Dev->waitEvent(Ev->EventInfo, Queue->AsyncInfo))
      return Err;
  }

  return Error::success();
}
|
|
|
|
/// Shared implementation of olGetQueueInfo and olGetQueueInfoSize.
///
/// Writes property \p PropName of \p Queue through an InfoWriter built from
/// \p PropSize, \p PropValue and \p PropSizeRet. The size query passes a zero
/// size and null value; the value query passes a null \p PropSizeRet.
Error olGetQueueInfoImplDetail(ol_queue_handle_t Queue,
                               ol_queue_info_t PropName, size_t PropSize,
                               void *PropValue, size_t *PropSizeRet) {
  InfoWriter Writer(PropSize, PropValue, PropSizeRet);

  if (PropName == OL_QUEUE_INFO_DEVICE)
    return Writer.write<ol_device_handle_t>(Queue->Device);

  if (PropName == OL_QUEUE_INFO_EMPTY) {
    // The queue counts as empty when the plugin reports no pending work.
    auto HasWork = Queue->Device->Device->hasPendingWork(Queue->AsyncInfo);
    if (auto Err = HasWork.takeError())
      return Err;
    return Writer.write<bool>(!*HasWork);
  }

  return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                            "olGetQueueInfo enum '%i' is invalid", PropName);
}
|
|
|
|
/// Write queue property \p PropName of \p Queue into the \p PropSize-byte
/// buffer at \p PropValue.
Error olGetQueueInfo_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName,
                          size_t PropSize, void *PropValue) {
  // Value query: the caller does not ask for the size to be reported back.
  return olGetQueueInfoImplDetail(Queue, PropName, PropSize, PropValue,
                                  /*PropSizeRet=*/nullptr);
}
|
|
|
|
/// Report through \p PropSizeRet the size of queue property \p PropName.
Error olGetQueueInfoSize_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName,
                              size_t *PropSizeRet) {
  // Size query: no destination buffer is supplied.
  return olGetQueueInfoImplDetail(Queue, PropName, /*PropSize=*/0,
                                  /*PropValue=*/nullptr, PropSizeRet);
}
|
|
|
|
/// Synchronize on \p Event through the plugin's syncEvent().
Error olSyncEvent_impl(ol_event_handle_t Event) {
  // An event without plugin event info was already complete when created.
  if (Event->EventInfo == nullptr)
    return Plugin::success();

  return Event->Device->Device->syncEvent(Event->EventInfo);
}
|
|
|
|
/// Destroy \p Event.
///
/// Any plugin-side event is destroyed first. If that fails, the error is
/// returned and the handle is not destroyed.
Error olDestroyEvent_impl(ol_event_handle_t Event) {
  // Events that were complete on creation carry no plugin-side event.
  const bool HasPluginEvent = Event->EventInfo != nullptr;
  if (HasPluginEvent) {
    if (auto Err = Event->Device->Device->destroyEvent(Event->EventInfo))
      return Err;
  }

  return olDestroy(Event);
}
|
|
|
|
/// Shared implementation of olGetEventInfo and olGetEventInfoSize.
///
/// Writes property \p PropName of \p Event through an InfoWriter built from
/// \p PropSize, \p PropValue and \p PropSizeRet. Completion is queried through
/// the queue the event was created on.
Error olGetEventInfoImplDetail(ol_event_handle_t Event,
                               ol_event_info_t PropName, size_t PropSize,
                               void *PropValue, size_t *PropSizeRet) {
  InfoWriter Writer(PropSize, PropValue, PropSizeRet);
  auto OwningQueue = Event->Queue;

  if (PropName == OL_EVENT_INFO_QUEUE)
    return Writer.write<ol_queue_handle_t>(OwningQueue);

  if (PropName == OL_EVENT_INFO_IS_COMPLETE) {
    // Events without plugin event info were complete on creation.
    bool Complete = true;
    if (Event->EventInfo) {
      auto Res = OwningQueue->Device->Device->isEventComplete(
          Event->EventInfo, OwningQueue->AsyncInfo);
      if (auto Err = Res.takeError())
        return Err;
      Complete = *Res;
    }
    return Writer.write<bool>(Complete);
  }

  return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                            "olGetEventInfo enum '%i' is invalid", PropName);
}
|
|
|
|
/// Write event property \p PropName of \p Event into the \p PropSize-byte
/// buffer at \p PropValue.
Error olGetEventInfo_impl(ol_event_handle_t Event, ol_event_info_t PropName,
                          size_t PropSize, void *PropValue) {
  // Value query: the caller does not ask for the size to be reported back.
  return olGetEventInfoImplDetail(Event, PropName, PropSize, PropValue,
                                  /*PropSizeRet=*/nullptr);
}
|
|
|
|
/// Report through \p PropSizeRet the size of event property \p PropName.
Error olGetEventInfoSize_impl(ol_event_handle_t Event, ol_event_info_t PropName,
                              size_t *PropSizeRet) {
  // Size query: no destination buffer is supplied.
  return olGetEventInfoImplDetail(Event, PropName, /*PropSize=*/0,
                                  /*PropValue=*/nullptr, PropSizeRet);
}
|
|
|
|
/// Create an event and record it on \p Queue, returning the handle in
/// \p EventOut.
///
/// If the queue has no pending work, no plugin event is created. The new
/// event keeps a null EventInfo and is treated as already complete.
///
/// On failure nothing is leaked: the partially built event (and its plugin
/// event, if one was already created) is destroyed, and \p EventOut is left
/// unchanged.
Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
  auto *Dev = Queue->Device->Device;

  auto Pending = Dev->hasPendingWork(Queue->AsyncInfo);
  if (auto Err = Pending.takeError())
    return Err;

  // Own the event until it is fully set up so early returns don't leak it.
  auto Event = std::make_unique<ol_event_impl_t>(nullptr, Queue->Device, Queue);

  if (*Pending) {
    if (auto Err = Dev->createEvent(&Event->EventInfo))
      return Err;

    if (auto Err = Dev->recordEvent(Event->EventInfo, Queue->AsyncInfo)) {
      // Don't leak the plugin event we just created; if that cleanup fails
      // as well, report both errors.
      if (auto DestroyErr = Dev->destroyEvent(Event->EventInfo))
        return joinErrors(std::move(Err), std::move(DestroyErr));
      return Err;
    }
  }

  *EventOut = Event.release();
  return Plugin::success();
}
|
|
|
|
/// Copy \p Size bytes from \p SrcPtr on \p SrcDevice to \p DstPtr on
/// \p DstDevice.
///
/// With a null \p Queue, a null async info is passed to the plugin, which
/// makes the copy synchronous. Otherwise the copy is enqueued on \p Queue.
/// Host-to-host copies are done with std::memcpy and are rejected if a queue
/// is given.
Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
                    ol_device_handle_t DstDevice, const void *SrcPtr,
                    ol_device_handle_t SrcDevice, size_t Size) {
  auto Host = OffloadContext::get().HostDevice;

  if (DstDevice == Host && SrcDevice == Host) {
    if (Queue)
      return createOffloadError(
          ErrorCode::INVALID_ARGUMENT,
          "at least one of DstDevice and SrcDevice must be a non-host device "
          "if queue is specified");

    // std::memcpy requires valid (non-null) pointers even when the size is
    // zero, so skip empty copies entirely.
    if (Size != 0)
      std::memcpy(DstPtr, SrcPtr, Size);
    return Error::success();
  }

  // If no queue is given the memcpy will be synchronous.
  auto AsyncInfo = Queue ? Queue->AsyncInfo : nullptr;

  if (DstDevice == Host)
    return SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, AsyncInfo);

  if (SrcDevice == Host)
    return DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, AsyncInfo);

  return SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device, DstPtr,
                                         Size, AsyncInfo);
}
|
|
|
|
/// Fill \p FillSize bytes at \p Ptr with the \p PatternSize-byte pattern at
/// \p PatternPtr, using the plugin's dataFill() on \p Queue.
Error olMemFill_impl(ol_queue_handle_t Queue, void *Ptr, size_t PatternSize,
                     const void *PatternPtr, size_t FillSize) {
  auto &Dev = *Queue->Device->Device;
  return Dev.dataFill(Ptr, PatternPtr, PatternSize, FillSize, Queue->AsyncInfo);
}
|
|
|
|
/// Load the \p ProgDataSize-byte binary at \p ProgData onto \p Device with
/// loadBinary() and return a new program handle for it in \p Program.
Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
                           size_t ProgDataSize, ol_program_handle_t *Program) {
  auto &Dev = *Device->Device;
  StringRef Image(static_cast<const char *>(ProgData), ProgDataSize);

  Expected<plugin::DeviceImageTy *> LoadedOrErr =
      Dev.loadBinary(Dev.Plugin, Image);
  if (auto Err = LoadedOrErr.takeError())
    return Err;

  plugin::DeviceImageTy *Loaded = *LoadedOrErr;
  assert(Loaded && "loadBinary returned nullptr");

  *Program = new ol_program_impl_t(Loaded, Loaded->getMemoryBuffer());
  return Error::success();
}
|
|
|
|
/// Set \p IsValid to whether \p Device's plugin considers the
/// \p ProgDataSize-byte binary at \p ProgData compatible with that device.
///
/// A handle with no underlying plugin device never accepts a binary
/// (presumably the host device; confirm).
Error olIsValidBinary_impl(ol_device_handle_t Device, const void *ProgData,
                           size_t ProgDataSize, bool *IsValid) {
  auto *Dev = Device->Device;
  if (!Dev) {
    *IsValid = false;
    return Error::success();
  }

  StringRef Image(static_cast<const char *>(ProgData), ProgDataSize);
  *IsValid = Dev->Plugin.isDeviceCompatible(Dev->getDeviceId(), Image);
  return Error::success();
}
|
|
|
|
/// Unload \p Program's image from its device, remove the image from the
/// device's LoadedImages list, and destroy the program handle.
Error olDestroyProgram_impl(ol_program_handle_t Program) {
  auto &Device = Program->Image->getDevice();
  if (auto Err = Device.unloadBinary(Program->Image))
    return Err;

  // erase(end()) is undefined behaviour, so only remove the image when it is
  // actually present in the device's list.
  auto &LoadedImages = Device.LoadedImages;
  auto It = std::find(LoadedImages.begin(), LoadedImages.end(), Program->Image);
  if (It != LoadedImages.end())
    LoadedImages.erase(It);

  return olDestroy(Program);
}
|
|
|
|
/// Store in \p GroupSize the value returned by the kernel's maxGroupSize() for
/// \p Device with \p DynamicMemSize bytes of dynamic memory.
///
/// Fails with SYMBOL_KIND if \p Kernel is not a kernel symbol.
Error olCalculateOptimalOccupancy_impl(ol_device_handle_t Device,
                                       ol_symbol_handle_t Kernel,
                                       size_t DynamicMemSize,
                                       size_t *GroupSize) {
  if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL)
    return createOffloadError(ErrorCode::SYMBOL_KIND,
                              "provided symbol is not a kernel");

  auto *Impl = std::get<GenericKernelTy *>(Kernel->PluginImpl);
  auto MaxOrErr = Impl->maxGroupSize(*Device->Device, DynamicMemSize);
  if (!MaxOrErr)
    return MaxOrErr.takeError();

  *GroupSize = *MaxOrErr;
  return Error::success();
}
|
|
|
|
/// Launch \p Kernel on \p Device using the sizes in \p LaunchSizeArgs and the
/// raw \p ArgumentsSize-byte argument buffer at \p ArgumentsData.
///
/// A non-null \p Queue must belong to \p Device (otherwise INVALID_DEVICE).
/// Its async info is used for the launch; a null queue passes a null async
/// info to the wrapper. Fails with SYMBOL_KIND if \p Kernel is not a kernel.
Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
                          ol_symbol_handle_t Kernel, const void *ArgumentsData,
                          size_t ArgumentsSize,
                          const ol_kernel_launch_size_args_t *LaunchSizeArgs) {
  auto *DeviceImpl = Device->Device;

  if (Queue && Queue->Device != Device)
    return createOffloadError(
        ErrorCode::INVALID_DEVICE,
        "device specified does not match the device of the given queue");

  if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL)
    return createOffloadError(ErrorCode::SYMBOL_KIND,
                              "provided symbol is not a kernel");

  const auto &Groups = LaunchSizeArgs->NumGroups;
  const auto &Threads = LaunchSizeArgs->GroupSize;

  KernelArgsTy LaunchArgs{};
  LaunchArgs.NumTeams[0] = Groups.x;
  LaunchArgs.NumTeams[1] = Groups.y;
  LaunchArgs.NumTeams[2] = Groups.z;
  LaunchArgs.ThreadLimit[0] = Threads.x;
  LaunchArgs.ThreadLimit[1] = Threads.y;
  LaunchArgs.ThreadLimit[2] = Threads.z;
  LaunchArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory;

  // Hand the argument buffer to the plugin as-is. IsCUDA is set so that no
  // pointer indirection is applied to the argument data.
  KernelLaunchParamsTy Params;
  Params.Data = const_cast<void *>(ArgumentsData);
  Params.Size = ArgumentsSize;
  LaunchArgs.ArgPtrs = reinterpret_cast<void **>(&Params);
  LaunchArgs.Flags.IsCUDA = true;

  auto *AsyncInfo = Queue ? Queue->AsyncInfo : nullptr;
  AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, AsyncInfo);

  auto *KernelImpl = std::get<GenericKernelTy *>(Kernel->PluginImpl);
  auto Err = KernelImpl->launch(*DeviceImpl, LaunchArgs.ArgPtrs, nullptr,
                                LaunchArgs, AsyncInfoWrapper);

  // The wrapper is finalized with the launch result (passed by reference)
  // before that result is returned to the caller.
  AsyncInfoWrapper.finalize(Err);
  return Err;
}
|
|
|
|
/// Return in \p Symbol the handle for the symbol \p Name of kind \p Kind in
/// \p Program, constructing it on first lookup.
///
/// Handles are cached in Program->KernelSymbols / Program->GlobalSymbols while
/// holding SymbolListMutex, so repeated lookups of a name return the same
/// handle. If construction fails, the cache slot stays empty and the next
/// lookup tries again.
Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
                       ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
  auto &Device = Program->Image->getDevice();
  std::lock_guard<std::mutex> Guard(Program->SymbolListMutex);

  if (Kind == OL_SYMBOL_KIND_KERNEL) {
    auto &Slot = Program->KernelSymbols[Name];
    if (!Slot) {
      auto KernelOrErr = Device.constructKernel(Name);
      if (!KernelOrErr)
        return KernelOrErr.takeError();

      if (auto Err = KernelOrErr->init(Device, *Program->Image))
        return Err;

      Slot = std::make_unique<ol_symbol_impl_t>(KernelOrErr->getName(),
                                                &*KernelOrErr);
    }
    *Symbol = Slot.get();
    return Error::success();
  }

  if (Kind == OL_SYMBOL_KIND_GLOBAL_VARIABLE) {
    auto &Slot = Program->GlobalSymbols[Name];
    if (!Slot) {
      GlobalTy Global{Name};
      if (auto Err =
              Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice(
                  Device, *Program->Image, Global))
        return Err;

      Slot = std::make_unique<ol_symbol_impl_t>(Global.getName().c_str(),
                                                std::move(Global));
    }
    *Symbol = Slot.get();
    return Error::success();
  }

  return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                            "getSymbol kind enum '%i' is invalid", Kind);
}
|
|
|
|
/// Shared implementation of olGetSymbolInfo and olGetSymbolInfoSize.
///
/// Writes property \p PropName of \p Symbol through an InfoWriter built from
/// \p PropSize, \p PropValue and \p PropSizeRet. Global-variable properties
/// fail with SYMBOL_KIND when \p Symbol is not a global variable.
Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
                                ol_symbol_info_t PropName, size_t PropSize,
                                void *PropValue, size_t *PropSizeRet) {
  InfoWriter Info(PropSize, PropValue, PropSizeRet);

  // Returns a SYMBOL_KIND error unless Symbol has kind Required.
  auto CheckKind = [&](ol_symbol_kind_t Required) -> Error {
    if (Symbol->Kind == Required)
      return Plugin::success();

    std::string ErrBuffer;
    llvm::raw_string_ostream(ErrBuffer)
        << PropName << ": Expected a symbol of Kind " << Required
        << " but given a symbol of Kind " << Symbol->Kind;
    // The message is built at runtime, so pass it as an argument rather than
    // as the printf-style format string (a '%' in it must not be expanded).
    return Plugin::error(ErrorCode::SYMBOL_KIND, "%s", ErrBuffer.c_str());
  };

  switch (PropName) {
  case OL_SYMBOL_INFO_KIND:
    return Info.write<ol_symbol_kind_t>(Symbol->Kind);
  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
    if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
      return Err;
    return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr());
  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE:
    if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
      return Err;
    return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize());
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "olGetSymbolInfo enum '%i' is invalid", PropName);
  }
}
|
|
|
|
/// Write symbol property \p PropName of \p Symbol into the \p PropSize-byte
/// buffer at \p PropValue.
Error olGetSymbolInfo_impl(ol_symbol_handle_t Symbol, ol_symbol_info_t PropName,
                           size_t PropSize, void *PropValue) {
  // Value query: the caller does not ask for the size to be reported back.
  return olGetSymbolInfoImplDetail(Symbol, PropName, PropSize, PropValue,
                                   /*PropSizeRet=*/nullptr);
}
|
|
|
|
/// Report through \p PropSizeRet the size of symbol property \p PropName.
Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
                               ol_symbol_info_t PropName, size_t *PropSizeRet) {
  // Size query: no destination buffer is supplied.
  return olGetSymbolInfoImplDetail(Symbol, PropName, /*PropSize=*/0,
                                   /*PropValue=*/nullptr, PropSizeRet);
}
|
|
|
|
/// Enqueue on \p Queue a host call of \p Callback with \p UserData, using the
/// plugin's enqueueHostCall().
Error olLaunchHostFunction_impl(ol_queue_handle_t Queue,
                                ol_host_function_cb_t Callback,
                                void *UserData) {
  auto &Dev = *Queue->Device->Device;
  return Dev.enqueueHostCall(Callback, UserData, Queue->AsyncInfo);
}
|
|
|
|
} // namespace offload
|
|
} // namespace llvm
|