This patch adds `olGetEventElapsedTime` to the new LLVM Offload API, as requested in [#185728](https://github.com/llvm/llvm-project/issues/185728), and adds the corresponding support in `plugins-nextgen`. A main motivation for this change is to make it possible to measure the elapsed time of work submitted to a queue, especially kernel launches. This is relevant to the intended use of the new Offload API for microbenchmarking GPU libc math functions. ### Summary The new API returns the elapsed time, in milliseconds, between two events on the same device. To support the common pattern `create start event → enqueue kernel → create end event → sync end event → get elapsed time`, `olCreateEvent` now always creates and records a backend event through the device interface. For backends that materialize real event state, this gives the event concrete backend state that can be used for elapsed-time measurement. For backends that do not materialize backend event state, `EventInfo` may still remain null and existing event operations continue to treat such events as trivially complete. Previously, an event created on an empty queue could be represented only as a logical event. That representation was sufficient for sync and completion queries, but it was not suitable for elapsed-time measurement because there was no backend event state to timestamp. The new behavior preserves the meaning of completion of prior work while also allowing backends with timing support to attach real event state. ### Changes in `plugins-nextgen` #### Common interface Add elapsed-time support to the common device and plugin interfaces: * `GenericPluginTy::get_event_elapsed_time` * `GenericDeviceTy::getEventElapsedTime` * `GenericDeviceTy::getEventElapsedTimeImpl` #### AMDGPU * Add the required ROCr declarations and wrappers. * Enable queue profiling at queue creation time. * Record events by enqueuing a real barrier marker packet on the stream. 
* Retain the timing signal needed to query the recorded marker later. * Implement `getEventElapsedTimeImpl` using `hsa_amd_profiling_get_dispatch_time`, converting the result to milliseconds with `HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY`. This follows the ROCm/HIP approach of enabling queue profiling at HSA queue creation time, while keeping the AMDGPU queue path simpler than the lazy-enable alternative discussed during review. #### CUDA * Add the required CUDA driver declarations and wrappers. * Implement `getEventElapsedTimeImpl` with `cuEventElapsedTime`. #### Host * Add `getEventElapsedTimeImpl` that stores `0.0f` in the output pointer, when present, and returns success. Reason: the host plugin does not materialize backend event state and already treats event operations as trivially successful. Returning `0.0f` preserves that model without introducing a new failure mode. #### Level Zero * Add `getEventElapsedTimeImpl`, but leave it unimplemented. Reason: the Level Zero plugin currently does not provide standalone backend event support for this event model. For example, `waitEventImpl` / `syncEventImpl` are still unimplemented there. --------- Signed-off-by: Leandro Augusto Lacerda Campos <leandrolcampos@yahoo.com.br> Signed-off-by: Leandro A. Lacerda Campos <leandrolcampos@yahoo.com.br>
1231 lines
44 KiB
C++
1231 lines
44 KiB
C++
//===- ol_impl.cpp - Implementation of the new LLVM/Offload API ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This contains the definitions of the new LLVM/Offload API entry points. See
|
|
// new-api/API/README.md for more information.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "OffloadImpl.hpp"
|
|
#include "Helpers.hpp"
|
|
#include "OffloadPrint.hpp"
|
|
#include "PluginManager.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include <OffloadAPI.h>
|
|
|
|
#include <cstdint>
|
|
#include <mutex>
|
|
|
|
// TODO: Some plugins expect to be linked into libomptarget which defines these
|
|
// symbols to implement ompt callbacks. The least invasive workaround here is to
|
|
// define them in libLLVMOffload as false/null so they are never used. In future
|
|
// it would be better to allow the plugins to implement callbacks without
|
|
// pulling in details from libomptarget.
|
|
#ifdef OMPT_SUPPORT
|
|
namespace llvm::omp::target {
|
|
namespace ompt {
|
|
bool Initialized = false;
|
|
ompt_get_callback_t lookupCallbackByCode = nullptr;
|
|
ompt_function_lookup_t lookupCallbackByName = nullptr;
|
|
} // namespace ompt
|
|
} // namespace llvm::omp::target
|
|
#endif
|
|
|
|
using namespace llvm::omp::target;
|
|
using namespace llvm::omp::target::plugin;
|
|
using namespace error;
|
|
|
|
struct ol_platform_impl_t {
|
|
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
|
|
ol_platform_backend_t BackendType)
|
|
: BackendType(BackendType), Plugin(std::move(Plugin)) {}
|
|
ol_platform_backend_t BackendType;
|
|
|
|
/// Complete all pending work for this platform and perform any needed
|
|
/// cleanup.
|
|
///
|
|
/// After calling this function, no liboffload functions should be called with
|
|
/// this platform handle.
|
|
llvm::Error destroy();
|
|
|
|
/// Initialize the associated plugin and devices.
|
|
llvm::Error init();
|
|
|
|
/// Direct access to the plugin, may be uninitialized if accessed here.
|
|
std::unique_ptr<GenericPluginTy> Plugin;
|
|
|
|
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
|
|
};
|
|
|
|
// Handle type definitions. Ideally these would be 1:1 with the plugins, but
|
|
// we add some additional data here for now to avoid churn in the plugin
|
|
// interface.
|
|
struct ol_device_impl_t {
|
|
ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
|
|
ol_platform_impl_t &Platform, InfoTreeNode &&DevInfo)
|
|
: DeviceNum(DeviceNum), Device(Device), Platform(Platform),
|
|
Info(std::forward<InfoTreeNode>(DevInfo)) {}
|
|
|
|
~ol_device_impl_t() {
|
|
assert(!OutstandingQueues.size() &&
|
|
"Device object dropped with outstanding queues");
|
|
}
|
|
|
|
int DeviceNum;
|
|
GenericDeviceTy *Device;
|
|
ol_platform_impl_t &Platform;
|
|
InfoTreeNode Info;
|
|
|
|
llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
|
|
std::mutex OutstandingQueuesMutex;
|
|
|
|
/// If the device has any outstanding queues that are now complete, remove it
|
|
/// from the list and return it.
|
|
///
|
|
/// Queues may be added to the outstanding queue list by olDestroyQueue if
|
|
/// they are destroyed but not completed.
|
|
__tgt_async_info *getOutstandingQueue() {
|
|
// Not locking the `size()` access is fine here - In the worst case we
|
|
// either miss a queue that exists or loop through an empty array after
|
|
// taking the lock. Both are sub-optimal but not that bad.
|
|
if (OutstandingQueues.size()) {
|
|
std::lock_guard<std::mutex> Lock(OutstandingQueuesMutex);
|
|
|
|
// As queues are pulled and popped from this list, longer running queues
|
|
// naturally bubble to the start of the array. Hence looping backwards.
|
|
for (auto Q = OutstandingQueues.rbegin(); Q != OutstandingQueues.rend();
|
|
Q++) {
|
|
if (!Device->hasPendingWork(*Q)) {
|
|
auto OutstandingQueue = *Q;
|
|
*Q = OutstandingQueues.back();
|
|
OutstandingQueues.pop_back();
|
|
return OutstandingQueue;
|
|
}
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
/// Complete all pending work for this device and perform any needed cleanup.
|
|
///
|
|
/// After calling this function, no liboffload functions should be called with
|
|
/// this device handle.
|
|
llvm::Error destroy() {
|
|
llvm::Error Result = Plugin::success();
|
|
for (auto Q : OutstandingQueues)
|
|
if (auto Err = Device->synchronize(Q, /*Release=*/true))
|
|
Result = llvm::joinErrors(std::move(Result), std::move(Err));
|
|
OutstandingQueues.clear();
|
|
return Result;
|
|
}
|
|
};
|
|
|
|
/// Destroys every device of the platform and then deinitializes the plugin.
/// Nothing short-circuits: all failures are joined into the returned error.
llvm::Error ol_platform_impl_t::destroy() {
  llvm::Error Accumulated = Plugin::success();

  // Folds a failing step into the accumulated result.
  auto Record = [&Accumulated](llvm::Error Step) {
    if (Step)
      Accumulated = llvm::joinErrors(std::move(Accumulated), std::move(Step));
  };

  for (auto &Dev : Devices)
    Record(Dev->destroy());

  Record(Plugin->deinit());

  return Accumulated;
}
|
|
|
|
/// Initializes the plugin, then each of its devices in index order, wrapping
/// every device in an ol_device_impl_t. Stops at and returns the first error.
llvm::Error ol_platform_impl_t::init() {
  // A platform without a plugin instance has nothing to bring up.
  if (!Plugin)
    return llvm::Error::success();

  if (llvm::Error PluginErr = Plugin->init())
    return PluginErr;

  const auto NumDevices = Plugin->getNumDevices();
  for (auto DeviceId = 0; DeviceId != NumDevices; ++DeviceId) {
    if (llvm::Error DeviceErr = Plugin->initDevice(DeviceId))
      return DeviceErr;

    GenericDeviceTy &PluginDevice = Plugin->getDevice(DeviceId);
    llvm::Expected<InfoTreeNode> DeviceInfo = PluginDevice.obtainInfo();
    if (!DeviceInfo)
      return DeviceInfo.takeError();

    Devices.push_back(std::make_unique<ol_device_impl_t>(
        DeviceId, &PluginDevice, *this, std::move(*DeviceInfo)));
  }

  return llvm::Error::success();
}
|
|
|
|
struct ol_queue_impl_t {
|
|
ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
|
|
: AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
|
|
__tgt_async_info *AsyncInfo;
|
|
ol_device_handle_t Device;
|
|
// A unique identifier for the queue
|
|
size_t Id;
|
|
static std::atomic<size_t> IdCounter;
|
|
};
|
|
std::atomic<size_t> ol_queue_impl_t::IdCounter(0);
|
|
|
|
struct ol_event_impl_t {
|
|
ol_event_impl_t(void *EventInfo, ol_device_handle_t Device,
|
|
ol_queue_handle_t Queue)
|
|
: EventInfo(EventInfo), Device(Device), QueueId(Queue->Id), Queue(Queue) {
|
|
}
|
|
// Opaque backend-specific event state. This is expected to be non-null for
|
|
// backends that materialize real events.
|
|
void *EventInfo;
|
|
ol_device_handle_t Device;
|
|
size_t QueueId;
|
|
// Events may outlive the queue - don't assume this is always valid.
|
|
// It is provided only to implement OL_EVENT_INFO_QUEUE. Use QueueId to check
|
|
// for queue equality instead.
|
|
ol_queue_handle_t Queue;
|
|
};
|
|
|
|
struct ol_program_impl_t {
|
|
ol_program_impl_t(plugin::DeviceImageTy *Image,
|
|
llvm::MemoryBufferRef DeviceImage)
|
|
: Image(Image), DeviceImage(DeviceImage) {}
|
|
plugin::DeviceImageTy *Image;
|
|
std::mutex SymbolListMutex;
|
|
llvm::MemoryBufferRef DeviceImage;
|
|
llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> KernelSymbols;
|
|
llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> GlobalSymbols;
|
|
};
|
|
|
|
struct ol_symbol_impl_t {
|
|
ol_symbol_impl_t(const char *Name, GenericKernelTy *Kernel)
|
|
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL), Name(Name) {}
|
|
ol_symbol_impl_t(const char *Name, GlobalTy &&Global)
|
|
: PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {}
|
|
std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
|
|
ol_symbol_kind_t Kind;
|
|
llvm::StringRef Name;
|
|
};
|
|
|
|
namespace llvm {
|
|
namespace offload {
|
|
|
|
struct AllocInfo {
|
|
ol_device_handle_t Device;
|
|
ol_alloc_type_t Type;
|
|
void *Start;
|
|
// One byte past the end
|
|
void *End;
|
|
};
|
|
|
|
// Global shared state for liboffload
|
|
struct OffloadContext;
|
|
// This pointer is non-null if and only if the context is valid and fully
|
|
// initialized
|
|
static std::atomic<OffloadContext *> OffloadContextVal;
|
|
std::mutex OffloadContextValMutex;
|
|
struct OffloadContext {
|
|
OffloadContext(OffloadContext &) = delete;
|
|
OffloadContext(OffloadContext &&) = delete;
|
|
OffloadContext &operator=(OffloadContext &) = delete;
|
|
OffloadContext &operator=(OffloadContext &&) = delete;
|
|
|
|
bool TracingEnabled = false;
|
|
bool ValidationEnabled = true;
|
|
DenseMap<void *, AllocInfo> AllocInfoMap{};
|
|
std::mutex AllocInfoMapMutex{};
|
|
// Partitioned list of memory base addresses. Each element in this list is a
|
|
// key in AllocInfoMap
|
|
SmallVector<void *> AllocBases{};
|
|
SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
|
|
size_t RefCount;
|
|
|
|
static OffloadContext &get() {
|
|
assert(OffloadContextVal);
|
|
return *OffloadContextVal;
|
|
}
|
|
};
|
|
|
|
// If the context is uninited, then we assume tracing is disabled
|
|
bool isTracingEnabled() {
|
|
return isOffloadInitialized() && OffloadContext::get().TracingEnabled;
|
|
}
|
|
bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; }
|
|
// The library counts as initialized exactly while a context is published.
bool isOffloadInitialized() {
  return OffloadContextVal.load() != nullptr;
}
|
|
|
|
/// Generic handle teardown: deletes the object behind \p Handle and reports
/// success. The handle must not be used afterwards.
template <typename HandleT> Error olDestroy(HandleT Handle) {
  delete Handle;

  return Error::success();
}
|
|
|
|
/// Maps a plugin name to its backend enumerator. Names that are not
/// recognized map to OL_PLATFORM_BACKEND_UNKNOWN.
constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
  if (Name == "amdgpu")
    return OL_PLATFORM_BACKEND_AMDGPU;
  if (Name == "cuda")
    return OL_PLATFORM_BACKEND_CUDA;
  if (Name == "host")
    return OL_PLATFORM_BACKEND_HOST;
  if (Name == "level_zero")
    return OL_PLATFORM_BACKEND_LEVEL_ZERO;
  return OL_PLATFORM_BACKEND_UNKNOWN;
}
|
|
|
|
// Every plugin exports this method to create an instance of the plugin type.
|
|
#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
|
|
#include "Shared/Targets.def"
|
|
|
|
/// Creates a platform for every plugin listed in Shared/Targets.def that is
/// requested through \p InitArgs (all of them when nothing is requested, and
/// the host plugin unconditionally), initializes each platform, and reads the
/// OFFLOAD_TRACE / OFFLOAD_DISABLE_VALIDATION environment switches.
///
/// Returns the first platform initialization error, if any.
Error initPlugins(OffloadContext &Context, const ol_init_args_t *InitArgs) {
  // Backends named by the caller. An empty set means "no filtering".
  SmallSet<ol_platform_backend_t, 0> Requested;
  const auto NumRequested = InitArgs ? InitArgs->NumPlatforms : 0;
  for (uint32_t Idx = 0; Idx < NumRequested; ++Idx)
    Requested.insert(InitArgs->Platforms[Idx]);

  // Try to instantiate each supported plugin, leaving out backends that were
  // not asked for. The host plugin is instantiated regardless.
#define PLUGIN_TARGET(Name)                                                    \
  do {                                                                         \
    const ol_platform_backend_t Backend = pluginNameToBackend(#Name);          \
    const bool Wanted = Requested.empty() ||                                   \
                        Backend == OL_PLATFORM_BACKEND_HOST ||                 \
                        Requested.contains(Backend);                           \
    if (Wanted)                                                                \
      Context.Platforms.push_back(std::make_unique<ol_platform_impl_t>(        \
          std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), Backend));  \
  } while (false);
#include "Shared/Targets.def"

  // Plugins and devices are initialized eagerly so that platform bring-up
  // happens at one well-defined point, which keeps the teardown order the
  // vendor libraries expect.
  for (auto &Platform : Context.Platforms)
    if (Error InitErr = Platform->init())
      return InitErr;

  // Either switch is considered set as soon as the variable exists.
  Context.TracingEnabled = std::getenv("OFFLOAD_TRACE") != nullptr;
  Context.ValidationEnabled =
      std::getenv("OFFLOAD_DISABLE_VALIDATION") == nullptr;

  return Plugin::success();
}
|
|
|
|
/// Initializes the library, or takes an additional reference when it already
/// is initialized. \p InitArgs may be null; when present it is validated
/// before any plugin is created.
///
/// The new context is published and counted even when plugin initialization
/// reports an error; that error is then returned to the caller.
Error olInit_impl(const ol_init_args_t *InitArgs) {
  std::lock_guard<std::mutex> Guard(OffloadContextValMutex);

  // Already up: only bump the reference count.
  if (isOffloadInitialized()) {
    ++OffloadContext::get().RefCount;
    return Plugin::success();
  }

  if (InitArgs && InitArgs->Size < sizeof(ol_init_args_t))
    return createOffloadError(ErrorCode::INVALID_SIZE,
                              "ol_init_args_t Size field is too small");
  if (InitArgs && InitArgs->NumPlatforms > 0 && !InitArgs->Platforms)
    return createOffloadError(ErrorCode::INVALID_NULL_POINTER,
                              "NumPlatforms > 0 but Platforms is null");

  // Build the context off to the side and publish it only after the plugins
  // were set up, so entry points reading OffloadContextVal never observe a
  // half-built context.
  auto *Fresh = new OffloadContext{};
  Error PluginStatus = initPlugins(*Fresh, InitArgs);
  OffloadContextVal.store(Fresh);
  ++Fresh->RefCount;

  return PluginStatus;
}
|
|
|
|
/// Drops one reference to the library. When the last reference goes away the
/// context is unpublished, every initialized platform is destroyed, and the
/// context is deleted. Destruction failures are joined into the result.
Error olShutDown_impl() {
  std::lock_guard<std::mutex> Guard(OffloadContextValMutex);

  // Other users still hold references; nothing to tear down yet.
  OffloadContext &Live = OffloadContext::get();
  Live.RefCount -= 1;
  if (Live.RefCount != 0)
    return Error::success();

  Error Combined = Error::success();
  OffloadContext *Retired = OffloadContextVal.exchange(nullptr);

  for (auto &Platform : Retired->Platforms) {
    // Platforms without a plugin instance, or whose plugin never finished
    // initializing, have nothing to deinitialize.
    const bool Usable = Platform->Plugin && Platform->Plugin->is_initialized();
    if (!Usable)
      continue;

    if (auto Err = Platform->destroy())
      Combined = joinErrors(std::move(Combined), std::move(Err));
  }

  delete Retired;
  return Combined;
}
|
|
|
|
/// Shared backend of olGetPlatformInfo and olGetPlatformInfoSize. The
/// requested property is written through an InfoWriter built from
/// \p PropSize, \p PropValue and \p PropSizeRet.
///
/// Returns INVALID_ENUMERATION for a property this function does not know.
Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
                                  ol_platform_info_t PropName, size_t PropSize,
                                  void *PropValue, size_t *PropSizeRet) {
  InfoWriter Writer(PropSize, PropValue, PropSizeRet);

  // The plugin may still be uninitialized at this point. Once a property is
  // added that needs an initialized plugin, it has to be initialized here.
  switch (PropName) {
  case OL_PLATFORM_INFO_NAME:
    return Writer.writeString(Platform->Plugin->getName());

  case OL_PLATFORM_INFO_VENDOR_NAME:
    // TODO: Implement this
    return Writer.writeString("Unknown platform vendor");

  case OL_PLATFORM_INFO_VERSION: {
    auto Version = formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR, OL_VERSION_MINOR,
                           OL_VERSION_PATCH)
                       .str();
    return Writer.writeString(Version);
  }

  case OL_PLATFORM_INFO_BACKEND:
    return Writer.write<ol_platform_backend_t>(Platform->BackendType);

  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "getPlatformInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}
|
|
|
|
/// Value query: writes the property into \p PropValue (capacity \p PropSize)
/// and does not report the required size.
Error olGetPlatformInfo_impl(ol_platform_handle_t Platform,
                             ol_platform_info_t PropName, size_t PropSize,
                             void *PropValue) {
  size_t *const NoSizeOut = nullptr;
  return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue,
                                     NoSizeOut);
}
|
|
|
|
/// Size query: passes no value buffer, only \p PropSizeRet.
Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
                                 ol_platform_info_t PropName,
                                 size_t *PropSizeRet) {
  const size_t NoCapacity = 0;
  void *const NoBuffer = nullptr;
  return olGetPlatformInfoImplDetail(Platform, PropName, NoCapacity, NoBuffer,
                                     PropSizeRet);
}
|
|
|
|
/// Hands \p Callback to the RPC server of the platform's plugin. Always
/// reports success.
Error olPlatformRegisterRPCCallback_impl(ol_platform_handle_t Platform,
                                         ol_platform_rpc_cb_t Callback) {
  auto &&Server = Platform->Plugin->getRPCServer();
  Server.registerCallback(Callback);

  return Error::success();
}
|
|
|
|
/// Shared backend of olGetDeviceInfo and olGetDeviceInfoSize.
///
/// Properties are answered in two stages: a first switch handles the ones
/// computed here without consulting the plugin's info tree; everything else
/// is looked up in the device's info tree and converted to the API type.
///
/// \returns INVALID_ENUMERATION for an out-of-range property, UNIMPLEMENTED
/// when the plugin supplied no entry for it, BACKEND_FAILURE when the entry
/// has an unexpected type or an out-of-range value, or whatever error the
/// InfoWriter / plugin reports.
Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
                                ol_device_info_t PropName, size_t PropSize,
                                void *PropValue, size_t *PropSizeRet) {
  InfoWriter Info(PropSize, PropValue, PropSizeRet);

  // Builds an error carrying the requested code, with the message prefixed by
  // the property being queried.
  auto makeError = [&](ErrorCode Code, StringRef Err) {
    std::string ErrBuffer;
    raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
    return Plugin::error(Code, ErrBuffer.c_str());
  };
  bool IsHost = Device->Platform.BackendType == OL_PLATFORM_BACKEND_HOST;

  // These are not implemented by the plugin interface
  switch (PropName) {
  case OL_DEVICE_INFO_PLATFORM:
    return Info.write<void *>(&Device->Platform);

  case OL_DEVICE_INFO_TYPE:
    // Every non-host backend is reported as a GPU.
    if (IsHost)
      return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
    else
      return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);

  case OL_DEVICE_INFO_SINGLE_FP_CONFIG:
  case OL_DEVICE_INFO_DOUBLE_FP_CONFIG: {
    // The same fixed capability set is reported for float and double.
    ol_device_fp_capability_flags_t flags{0};
    flags |= OL_DEVICE_FP_CAPABILITY_FLAG_CORRECTLY_ROUNDED_DIVIDE_SQRT |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_NEAREST |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_ZERO |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_INF |
             OL_DEVICE_FP_CAPABILITY_FLAG_INF_NAN |
             OL_DEVICE_FP_CAPABILITY_FLAG_DENORM |
             OL_DEVICE_FP_CAPABILITY_FLAG_FMA;
    return Info.write(flags);
  }

  case OL_DEVICE_INFO_HALF_FP_CONFIG:
    return Info.write<ol_device_fp_capability_flags_t>(0);

  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_CHAR:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_SHORT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_INT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_FLOAT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_DOUBLE:
    return Info.write<uint32_t>(1);

  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_HALF:
    return Info.write<uint32_t>(0);

  // None of the existing plugins specify a limit on a single allocation,
  // so return the global memory size instead
  case OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
    [[fallthrough]];
  // AMD doesn't provide the global memory size (trivially) with the device info
  // struct, so use the plugin interface
  case OL_DEVICE_INFO_GLOBAL_MEM_SIZE: {
    uint64_t Mem;
    if (auto Err = Device->Device->getDeviceMemorySize(Mem))
      return Err;
    return Info.write<uint64_t>(Mem);
  } break;

  default:
    break;
  }

  if (PropName >= OL_DEVICE_INFO_LAST)
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "getDeviceInfo enum '%i' is invalid", PropName);

  auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName));
  if (!EntryOpt)
    return makeError(ErrorCode::UNIMPLEMENTED,
                     "plugin did not provide a response for this information");
  auto Entry = *EntryOpt;

  // Retrieve properties from the plugin interface
  switch (PropName) {
  case OL_DEVICE_INFO_NAME:
  case OL_DEVICE_INFO_PRODUCT_NAME:
  case OL_DEVICE_INFO_UID:
  case OL_DEVICE_INFO_VENDOR:
  case OL_DEVICE_INFO_DRIVER_VERSION: {
    // String values
    if (!std::holds_alternative<std::string>(Entry->Value))
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned incorrect type");
    return Info.writeString(std::get<std::string>(Entry->Value).c_str());
  }

  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
  case OL_DEVICE_INFO_MAX_WORK_SIZE:
  case OL_DEVICE_INFO_VENDOR_ID:
  case OL_DEVICE_INFO_NUM_COMPUTE_UNITS:
  case OL_DEVICE_INFO_ADDRESS_BITS:
  case OL_DEVICE_INFO_MAX_CLOCK_FREQUENCY:
  case OL_DEVICE_INFO_MEMORY_CLOCK_RATE: {
    // Uint32 values: the plugin stores them as uint64_t, so the value is
    // range-checked before it is narrowed.
    if (!std::holds_alternative<uint64_t>(Entry->Value))
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned incorrect type");
    auto Value = std::get<uint64_t>(Entry->Value);
    if (Value > std::numeric_limits<uint32_t>::max())
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned out of range device info");
    return Info.write(static_cast<uint32_t>(Value));
  }

  case OL_DEVICE_INFO_WORK_GROUP_LOCAL_MEM_SIZE: {
    // Uint64 value, written without narrowing.
    if (!std::holds_alternative<uint64_t>(Entry->Value))
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned incorrect type");
    return Info.write(std::get<uint64_t>(Entry->Value));
  }

  case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION:
  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION: {
    // {x, y, z} triples
    ol_dimensions_t Out{0, 0, 0};

    // Reads the child entry called Name into Dest. Note that the uint64_t
    // stored by the plugin is narrowed to uint32_t without a range check.
    auto getField = [&](StringRef Name, uint32_t &Dest) {
      if (auto F = Entry->get(Name)) {
        if (!std::holds_alternative<uint64_t>((*F)->Value))
          return makeError(
              ErrorCode::BACKEND_FAILURE,
              "plugin returned incorrect type for dimensions element");
        Dest = std::get<uint64_t>((*F)->Value);
      } else
        return makeError(ErrorCode::BACKEND_FAILURE,
                         "plugin didn't provide all values for dimensions");
      return Plugin::success();
    };

    if (auto Res = getField("x", Out.x))
      return Res;
    if (auto Res = getField("y", Out.y))
      return Res;
    if (auto Res = getField("z", Out.z))
      return Res;

    return Info.write(Out);
  }

  default:
    llvm_unreachable("Unimplemented device info");
  }
}
|
|
|
|
/// Value query: writes the property into \p PropValue (capacity \p PropSize)
/// and does not report the required size.
Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
                           size_t PropSize, void *PropValue) {
  size_t *const NoSizeOut = nullptr;
  return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
                                   NoSizeOut);
}
|
|
|
|
/// Size query: passes no value buffer, only \p PropSizeRet.
Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
                               ol_device_info_t PropName, size_t *PropSizeRet) {
  const size_t NoCapacity = 0;
  void *const NoBuffer = nullptr;
  return olGetDeviceInfoImplDetail(Device, PropName, NoCapacity, NoBuffer,
                                   PropSizeRet);
}
|
|
|
|
/// Invokes \p Callback once per device, platform by platform, forwarding
/// \p UserData. Iteration ends early as soon as the callback returns a false
/// value. Always reports success.
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
  for (auto &Platform : OffloadContext::get().Platforms)
    for (auto &Device : Platform->Devices) {
      const bool KeepGoing = Callback(Device.get(), UserData);
      if (!KeepGoing)
        return Error::success();
    }

  return Error::success();
}
|
|
|
|
/// Translates an API allocation kind into the plugin allocation kind.
/// OL_ALLOC_TYPE_MANAGED and any unrecognized value map to
/// TARGET_ALLOC_SHARED.
TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {
  if (Type == OL_ALLOC_TYPE_DEVICE)
    return TARGET_ALLOC_DEVICE;
  if (Type == OL_ALLOC_TYPE_HOST)
    return TARGET_ALLOC_HOST;
  return TARGET_ALLOC_SHARED;
}
|
|
|
|
constexpr size_t MAX_ALLOC_TRIES = 50;
|
|
Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
|
|
size_t Size, void **AllocationOut) {
|
|
SmallVector<void *> Rejects;
|
|
|
|
// Repeat the allocation up to a certain amount of times. If it happens to
|
|
// already be allocated (e.g. by a device from another vendor) throw it away
|
|
// and try again.
|
|
for (size_t Count = 0; Count < MAX_ALLOC_TRIES; Count++) {
|
|
auto NewAlloc = Device->Device->dataAlloc(Size, nullptr,
|
|
convertOlToPluginAllocTy(Type));
|
|
if (!NewAlloc)
|
|
return NewAlloc.takeError();
|
|
|
|
void *NewEnd = &static_cast<char *>(*NewAlloc)[Size];
|
|
auto &AllocBases = OffloadContext::get().AllocBases;
|
|
auto &AllocInfoMap = OffloadContext::get().AllocInfoMap;
|
|
{
|
|
std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
|
|
|
|
// Check that this memory region doesn't overlap another one
|
|
// That is, the start of this allocation needs to be after another
|
|
// allocation's end point, and the end of this allocation needs to be
|
|
// before the next one's start.
|
|
// `Gap` is the first alloc who ends after the new alloc's start point.
|
|
auto Gap =
|
|
std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc,
|
|
[&](const void *Iter, const void *Val) {
|
|
return AllocInfoMap.at(Iter).End <= Val;
|
|
});
|
|
if (Gap == AllocBases.end() || NewEnd <= AllocInfoMap.at(*Gap).Start) {
|
|
// Success, no conflict
|
|
AllocInfoMap.insert_or_assign(
|
|
*NewAlloc, AllocInfo{Device, Type, *NewAlloc, NewEnd});
|
|
AllocBases.insert(
|
|
std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc),
|
|
*NewAlloc);
|
|
*AllocationOut = *NewAlloc;
|
|
|
|
for (void *R : Rejects)
|
|
if (auto Err =
|
|
Device->Device->dataDelete(R, convertOlToPluginAllocTy(Type)))
|
|
return Err;
|
|
return Error::success();
|
|
}
|
|
|
|
// To avoid the next attempt allocating the same memory we just freed, we
|
|
// hold onto it until we complete the allocation
|
|
Rejects.push_back(*NewAlloc);
|
|
}
|
|
}
|
|
|
|
// We've tried multiple times, and can't allocate a non-overlapping region.
|
|
return createOffloadError(ErrorCode::BACKEND_FAILURE,
|
|
"failed to allocate non-overlapping memory");
|
|
}
|
|
|
|
/// Frees an allocation made by olMemAlloc. \p Address must be the base
/// address of a registered allocation; otherwise INVALID_ARGUMENT is
/// returned. The bookkeeping entry is removed before the plugin is asked to
/// release the memory.
Error olMemFree_impl(void *Address) {
  ol_device_handle_t Owner;
  ol_alloc_type_t Kind;
  {
    OffloadContext &Context = OffloadContext::get();
    std::lock_guard<std::mutex> Guard(Context.AllocInfoMapMutex);

    auto Found = Context.AllocInfoMap.find(Address);
    if (Found == Context.AllocInfoMap.end())
      return createOffloadError(ErrorCode::INVALID_ARGUMENT,
                                "address is not a known allocation");

    // Remember what is needed for the plugin call, then drop the record.
    Owner = Found->second.Device;
    Kind = Found->second.Type;
    Context.AllocInfoMap.erase(Found);

    auto &Bases = Context.AllocBases;
    Bases.erase(std::lower_bound(Bases.begin(), Bases.end(), Address));
  }

  // The plugin call is made after the table lock has been released.
  return Owner->Device->dataDelete(Address, convertOlToPluginAllocTy(Kind));
}
|
|
|
|
/// Shared backend of olGetMemInfo and olGetMemInfoSize. \p Ptr may be the
/// base of a registered allocation or any address inside one.
///
/// Returns NOT_FOUND when \p Ptr lies in no registered allocation and
/// INVALID_ENUMERATION for an unknown property.
Error olGetMemInfoImplDetail(const void *Ptr, ol_mem_info_t PropName,
                             size_t PropSize, void *PropValue,
                             size_t *PropSizeRet) {
  InfoWriter Writer(PropSize, PropValue, PropSizeRet);
  std::lock_guard<std::mutex> Guard(OffloadContext::get().AllocInfoMapMutex);

  auto &Bases = OffloadContext::get().AllocBases;
  auto &Records = OffloadContext::get().AllocInfoMap;

  const AllocInfo *Alloc = nullptr;
  if (auto Direct = Records.find(Ptr); Direct != Records.end()) {
    // Quick path: the caller passed an allocation base address.
    Alloc = &Direct->second;
  } else {
    // Slow path: locate the first allocation ending after Ptr, then make sure
    // Ptr is not below that allocation's start.
    auto Candidate = std::lower_bound(Bases.begin(), Bases.end(), Ptr,
                                      [&](const void *Base, const void *Val) {
                                        return Records.at(Base).End <= Val;
                                      });
    if (Candidate == Bases.end() || Ptr < Records.at(*Candidate).Start)
      return Plugin::error(ErrorCode::NOT_FOUND,
                           "allocated memory information not found");
    Alloc = &Records.at(*Candidate);
  }

  switch (PropName) {
  case OL_MEM_INFO_DEVICE:
    return Writer.write<ol_device_handle_t>(Alloc->Device);
  case OL_MEM_INFO_BASE:
    return Writer.write<void *>(Alloc->Start);
  case OL_MEM_INFO_SIZE: {
    // Size is the byte distance between the start and the past-the-end
    // address.
    char *const First = static_cast<char *>(Alloc->Start);
    char *const Last = static_cast<char *>(Alloc->End);
    return Writer.write<size_t>(Last - First);
  }
  case OL_MEM_INFO_TYPE:
    return Writer.write<ol_alloc_type_t>(Alloc->Type);
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "olGetMemInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}
|
|
|
|
/// Value query: writes the property into \p PropValue (capacity \p PropSize)
/// and does not report the required size.
Error olGetMemInfo_impl(const void *Ptr, ol_mem_info_t PropName,
                        size_t PropSize, void *PropValue) {
  size_t *const NoSizeOut = nullptr;
  return olGetMemInfoImplDetail(Ptr, PropName, PropSize, PropValue, NoSizeOut);
}
|
|
|
|
/// Size query: passes no value buffer, only \p PropSizeRet.
Error olGetMemInfoSize_impl(const void *Ptr, ol_mem_info_t PropName,
                            size_t *PropSizeRet) {
  const size_t NoCapacity = 0;
  void *const NoBuffer = nullptr;
  return olGetMemInfoImplDetail(Ptr, PropName, NoCapacity, NoBuffer,
                                PropSizeRet);
}
|
|
|
|
/// Creates a queue on \p Device. A finished queue from the device's
/// outstanding list is reused when one is available; otherwise the plugin is
/// asked for a fresh async-info object. \p Queue is written only on success.
Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
  auto NewQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);

  if (__tgt_async_info *Recycled = Device->getOutstandingQueue()) {
    // Although the recycled queue has no work left, synchronizing it is still
    // required so temporary allocations are released and other cleanup runs.
    if (auto Err = Device->Device->synchronize(Recycled, /*Release=*/false))
      return Err;
    NewQueue->AsyncInfo = Recycled;
  } else {
    if (auto Err = Device->Device->initAsyncInfo(&NewQueue->AsyncInfo))
      return Err;
  }

  *Queue = NewQueue.release();
  return Error::success();
}
|
|
|
|
/// Destroys a queue handle. A queue without pending work is synchronized and
/// released immediately; a queue that is still busy is parked on the device's
/// outstanding list for later reuse or cleanup. The handle object itself is
/// deleted unless an error is returned first.
Error olDestroyQueue_impl(ol_queue_handle_t Queue) {
  ol_device_handle_t Owner = Queue->Device;

  // Checking once is sufficient: after olDestroyQueue has been called no new
  // work can be added to the queue, so a queue found finished stays finished.
  auto Pending = Owner->Device->hasPendingWork(Queue->AsyncInfo);
  if (!Pending)
    return Pending.takeError();

  const bool StillBusy = *Pending;
  if (StillBusy) {
    // Work is still in flight. Remember the queue so it can be checked later.
    std::lock_guard<std::mutex> Guard(Owner->OutstandingQueuesMutex);
    Owner->OutstandingQueues.push_back(Queue->AsyncInfo);
  } else if (auto Err = Owner->Device->synchronize(Queue->AsyncInfo,
                                                   /*Release=*/true)) {
    // Finished queues are synchronized and returned to the pool right away.
    return Err;
  }

  return olDestroy(Queue);
}
|
|
|
|
/// Blocks until the work submitted to \p Queue has completed.
Error olSyncQueue_impl(ol_queue_handle_t Queue) {
  // The host plugin never sets a backend queue. Calling synchronize would not
  // be safe then, and there is nothing to wait for anyway.
  if (!Queue->AsyncInfo->Queue)
    return Error::success();

  // Release=false keeps the underlying queue object alive: it does not need
  // to be released here, and other offload threads should be able to keep
  // submitting work concurrently.
  return Queue->Device->Device->synchronize(Queue->AsyncInfo, false);
}
|
|
|
|
/// Makes \p Queue wait on each of the \p NumEvents events in \p Events.
///
/// Returns INVALID_NULL_HANDLE at the first null entry, or the first error
/// reported by the plugin.
Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events,
                        size_t NumEvents) {
  auto *PluginDevice = Queue->Device->Device;

  for (ol_event_handle_t *It = Events, *End = Events + NumEvents; It != End;
       ++It) {
    ol_event_handle_t Event = *It;
    if (!Event)
      return Plugin::error(ErrorCode::INVALID_NULL_HANDLE,
                           "olWaitEvents asked to wait on a NULL event");

    // A wait is only needed for an event from a different queue that carries
    // materialized backend state.
    const bool NeedsWait = Event->QueueId != Queue->Id && Event->EventInfo;
    if (!NeedsWait)
      continue;

    if (auto Err = PluginDevice->waitEvent(Event->EventInfo, Queue->AsyncInfo))
      return Err;
  }

  return Error::success();
}
|
|
|
|
/// Common implementation behind olGetQueueInfo_impl and
/// olGetQueueInfoSize_impl. \p PropSize, \p PropValue and \p PropSizeRet are
/// handed to an InfoWriter, which performs the actual write.
///
/// Supported properties:
///  - OL_QUEUE_INFO_DEVICE: the device handle the queue belongs to.
///  - OL_QUEUE_INFO_EMPTY: true when the device reports no pending work.
/// Any other value yields INVALID_ENUMERATION.
Error olGetQueueInfoImplDetail(ol_queue_handle_t Queue,
                               ol_queue_info_t PropName, size_t PropSize,
                               void *PropValue, size_t *PropSizeRet) {
  InfoWriter Writer(PropSize, PropValue, PropSizeRet);

  if (PropName == OL_QUEUE_INFO_DEVICE)
    return Writer.write<ol_device_handle_t>(Queue->Device);

  if (PropName == OL_QUEUE_INFO_EMPTY) {
    auto PendingOrErr =
        Queue->Device->Device->hasPendingWork(Queue->AsyncInfo);
    if (!PendingOrErr)
      return PendingOrErr.takeError();
    // "Empty" is the negation of "has pending work".
    return Writer.write<bool>(!*PendingOrErr);
  }

  return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                            "olGetQueueInfo enum '%i' is invalid", PropName);
}
|
|
|
|
/// Query queue property \p PropName into \p PropValue (of \p PropSize bytes).
/// Forwards to olGetQueueInfoImplDetail with no size-return pointer.
Error olGetQueueInfo_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName,
                          size_t PropSize, void *PropValue) {
  size_t *const NoSizeRet = nullptr;
  return olGetQueueInfoImplDetail(Queue, PropName, PropSize, PropValue,
                                  NoSizeRet);
}
|
|
|
|
/// Query the size of queue property \p PropName into \p PropSizeRet.
/// Forwards to olGetQueueInfoImplDetail with a zero-sized, null value buffer.
Error olGetQueueInfoSize_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName,
                              size_t *PropSizeRet) {
  const size_t NoValueSize = 0;
  void *const NoValue = nullptr;
  return olGetQueueInfoImplDetail(Queue, PropName, NoValueSize, NoValue,
                                  PropSizeRet);
}
|
|
|
|
/// Synchronize on \p Event through its device's syncEvent.
Error olSyncEvent_impl(ol_event_handle_t Event) {
  // An event with backend state is synchronized by the device.
  if (Event->EventInfo)
    return Event->Device->Device->syncEvent(Event->EventInfo);

  // Some backends do not materialize backend event state; such events are
  // treated as trivially complete.
  return Plugin::success();
}
|
|
|
|
/// Store in \p ElapsedTime the time elapsed between \p StartEvent and
/// \p EndEvent, as reported by the device's getEventElapsedTime.
///
/// Both events must belong to the same device; otherwise INVALID_DEVICE is
/// returned and \p ElapsedTime is left untouched.
Error olGetEventElapsedTime_impl(ol_event_handle_t StartEvent,
                                 ol_event_handle_t EndEvent,
                                 float *ElapsedTime) {
  auto *CommonDevice = StartEvent->Device;
  if (CommonDevice != EndEvent->Device)
    return createOffloadError(
        ErrorCode::INVALID_DEVICE,
        "StartEvent and EndEvent must belong to the same device");

  auto TimeOrErr = CommonDevice->Device->getEventElapsedTime(
      StartEvent->EventInfo, EndEvent->EventInfo);
  if (auto Err = TimeOrErr.takeError())
    return Err;

  *ElapsedTime = *TimeOrErr;
  return Error::success();
}
|
|
|
|
/// Destroy \p Event: release its backend event state, if any, through the
/// device and then destroy the handle with olDestroy.
Error olDestroyEvent_impl(ol_event_handle_t Event) {
  // Only events that carry backend state need a device-side destroy.
  if (Event->EventInfo) {
    auto DestroyErr = Event->Device->Device->destroyEvent(Event->EventInfo);
    if (DestroyErr)
      return DestroyErr;
  }

  return olDestroy(Event);
}
|
|
|
|
/// Common implementation behind olGetEventInfo_impl and
/// olGetEventInfoSize_impl. \p PropSize, \p PropValue and \p PropSizeRet are
/// handed to an InfoWriter, which performs the actual write.
///
/// Supported properties:
///  - OL_EVENT_INFO_QUEUE: the queue handle stored on the event.
///  - OL_EVENT_INFO_IS_COMPLETE: whether the event has completed.
/// Any other value yields INVALID_ENUMERATION.
Error olGetEventInfoImplDetail(ol_event_handle_t Event,
                               ol_event_info_t PropName, size_t PropSize,
                               void *PropValue, size_t *PropSizeRet) {
  InfoWriter Writer(PropSize, PropValue, PropSizeRet);
  auto OwnerQueue = Event->Queue;

  if (PropName == OL_EVENT_INFO_QUEUE)
    return Writer.write<ol_queue_handle_t>(OwnerQueue);

  if (PropName == OL_EVENT_INFO_IS_COMPLETE) {
    // Backends that do not materialize event state leave EventInfo null;
    // report such events as trivially complete.
    if (!Event->EventInfo)
      return Writer.write<bool>(true);

    auto CompleteOrErr = OwnerQueue->Device->Device->isEventComplete(
        Event->EventInfo, OwnerQueue->AsyncInfo);
    if (!CompleteOrErr)
      return CompleteOrErr.takeError();
    return Writer.write<bool>(*CompleteOrErr);
  }

  return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                            "olGetEventInfo enum '%i' is invalid", PropName);
}
|
|
|
|
/// Query event property \p PropName into \p PropValue (of \p PropSize bytes).
/// Forwards to olGetEventInfoImplDetail with no size-return pointer.
Error olGetEventInfo_impl(ol_event_handle_t Event, ol_event_info_t PropName,
                          size_t PropSize, void *PropValue) {
  size_t *const NoSizeRet = nullptr;
  return olGetEventInfoImplDetail(Event, PropName, PropSize, PropValue,
                                  NoSizeRet);
}
|
|
|
|
/// Query the size of event property \p PropName into \p PropSizeRet.
/// Forwards to olGetEventInfoImplDetail with a zero-sized, null value buffer.
Error olGetEventInfoSize_impl(ol_event_handle_t Event, ol_event_info_t PropName,
                              size_t *PropSizeRet) {
  const size_t NoValueSize = 0;
  void *const NoValue = nullptr;
  return olGetEventInfoImplDetail(Event, PropName, NoValueSize, NoValue,
                                  PropSizeRet);
}
|
|
|
|
/// Create an event on \p Queue and return it through \p EventOut.
///
/// The event is always created and recorded through the device interface.
/// If recording fails, any backend event state that was created is destroyed
/// again and the record error (joined with a destroy error, if there is one)
/// is returned; \p EventOut is not written in that case.
Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
  auto *PluginDevice = Queue->Device->Device;
  auto NewEvent =
      std::make_unique<ol_event_impl_t>(nullptr, Queue->Device, Queue);

  if (auto CreateErr = PluginDevice->createEvent(&NewEvent->EventInfo))
    return CreateErr;

  auto RecordErr =
      PluginDevice->recordEvent(NewEvent->EventInfo, Queue->AsyncInfo);
  if (!RecordErr) {
    // Recorded successfully: ownership of the event moves to the caller.
    *EventOut = NewEvent.release();
    return Error::success();
  }

  // Recording failed. A backend that materialized no event state has nothing
  // to clean up.
  if (!NewEvent->EventInfo)
    return RecordErr;

  if (auto DestroyErr = PluginDevice->destroyEvent(NewEvent->EventInfo))
    return joinErrors(std::move(RecordErr), std::move(DestroyErr));

  return RecordErr;
}
|
|
|
|
/// Copy \p Size bytes from \p SrcPtr on \p SrcDevice to \p DstPtr on
/// \p DstDevice.
///
/// The transfer path is chosen in this order:
///  - host to host: plain std::memcpy; only valid when \p Queue is null,
///    otherwise INVALID_ARGUMENT is returned;
///  - destination is host: dataRetrieve on the source device;
///  - source is host: dataSubmit on the destination device;
///  - both devices share a plugin that reports the pair as exchangable:
///    dataExchange on the source device;
///  - otherwise: \p Queue (if any) is synced and the data is staged through a
///    temporary host buffer using queue-less retrieve and submit calls.
///
/// If \p Queue is null a null async info is forwarded to the device calls.
Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
                    ol_device_handle_t DstDevice, const void *SrcPtr,
                    ol_device_handle_t SrcDevice, size_t Size) {
  bool IsDstHost = DstDevice->Platform.BackendType == OL_PLATFORM_BACKEND_HOST;
  bool IsSrcHost = SrcDevice->Platform.BackendType == OL_PLATFORM_BACKEND_HOST;

  if (IsDstHost && IsSrcHost) {
    if (Queue)
      return createOffloadError(
          ErrorCode::INVALID_ARGUMENT,
          "one of DstDevice and SrcDevice must be a non-host device if "
          "queue is specified");

    std::memcpy(DstPtr, SrcPtr, Size);
    return Error::success();
  }

  // If no queue is given the memcpy will be synchronous
  auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr;

  if (IsDstHost) {
    if (auto Res =
            SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl))
      return Res;
  } else if (IsSrcHost) {
    if (auto Res =
            DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl))
      return Res;
  } else if (SrcDevice->Platform.Plugin == DstDevice->Platform.Plugin &&
             SrcDevice->Platform.Plugin->isDataExchangable(
                 SrcDevice->Device->getDeviceId(),
                 DstDevice->Device->getDeviceId())) {
    if (auto Res = SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device,
                                                   DstPtr, Size, QueueImpl))
      return Res;
  } else {
    // No direct path between the two devices. The staged copy below does not
    // use the queue, so sync the queue first when one was given.
    if (Queue)
      if (auto Res = olSyncQueue_impl(Queue))
        return Res;

    void *Buffer = malloc(Size);
    if (!Buffer)
      return createOffloadError(ErrorCode::OUT_OF_RESOURCES,
                                "Couldn't allocate a buffer for transfer");
    Error Res = SrcDevice->Device->dataRetrieve(Buffer, SrcPtr, Size, nullptr);
    // Only push the data on to the destination if the retrieve succeeded.
    if (!Res)
      Res = DstDevice->Device->dataSubmit(DstPtr, Buffer, Size, nullptr);

    // The staging buffer is released on both the success and error paths.
    free(Buffer);
    return Res;
  }

  return Error::success();
}
|
|
|
|
/// Fill \p FillSize bytes at \p Ptr with the \p PatternSize byte pattern at
/// \p PatternPtr by forwarding to dataFill on the queue's device, using the
/// queue's async info.
Error olMemFill_impl(ol_queue_handle_t Queue, void *Ptr, size_t PatternSize,
                     const void *PatternPtr, size_t FillSize) {
  auto *PluginDevice = Queue->Device->Device;
  // Note the argument order expected by dataFill: pattern pointer first, then
  // pattern size.
  return PluginDevice->dataFill(Ptr, PatternPtr, PatternSize, FillSize,
                                Queue->AsyncInfo);
}
|
|
|
|
/// Load the \p ProgDataSize byte binary at \p ProgData onto \p Device and
/// return a newly allocated program handle through \p Program.
Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
                           size_t ProgDataSize, ol_program_handle_t *Program) {
  auto *PluginDevice = Device->Device;
  StringRef Binary(reinterpret_cast<const char *>(ProgData), ProgDataSize);

  Expected<plugin::DeviceImageTy *> ImageOrErr =
      PluginDevice->loadBinary(PluginDevice->Plugin, Binary);
  if (auto Err = ImageOrErr.takeError())
    return Err;

  plugin::DeviceImageTy *Image = *ImageOrErr;
  assert(Image && "loadBinary returned nullptr");

  // The program handle is built from the image and the image's memory buffer.
  *Program = new ol_program_impl_t(Image, Image->getMemoryBuffer());
  return Error::success();
}
|
|
|
|
/// Report through \p IsValid whether the \p ProgDataSize byte binary at
/// \p ProgData is compatible with \p Device, as judged by the plugin's
/// isDeviceCompatible. A handle without a plugin device is never compatible.
/// The function itself always returns success.
Error olIsValidBinary_impl(ol_device_handle_t Device, const void *ProgData,
                           size_t ProgDataSize, bool *IsValid) {
  auto *PluginDevice = Device->Device;
  if (!PluginDevice) {
    *IsValid = false;
    return Error::success();
  }

  StringRef Binary(reinterpret_cast<const char *>(ProgData), ProgDataSize);
  *IsValid = PluginDevice->Plugin.isDeviceCompatible(
      PluginDevice->getDeviceId(), Binary);
  return Error::success();
}
|
|
|
|
/// Destroy \p Program: unload its image from the image's device, remove the
/// image from that device's LoadedImages list, and destroy the handle with
/// olDestroy. If unloading fails the error is returned and nothing else is
/// touched.
Error olDestroyProgram_impl(ol_program_handle_t Program) {
  auto &Device = Program->Image->getDevice();
  if (auto Err = Device.unloadBinary(Program->Image))
    return Err;

  // Remove the image from the device's bookkeeping. The lookup result is
  // checked first because erasing the end iterator is undefined behavior if
  // the image is, for whatever reason, not in the list.
  auto &LoadedImages = Device.LoadedImages;
  auto ImageIt =
      std::find(LoadedImages.begin(), LoadedImages.end(), Program->Image);
  if (ImageIt != LoadedImages.end())
    LoadedImages.erase(ImageIt);

  return olDestroy(Program);
}
|
|
|
|
/// Store in \p GroupSize the value returned by the kernel's maxGroupSize for
/// \p Device and \p DynamicMemSize. \p Kernel must be a kernel symbol;
/// otherwise SYMBOL_KIND is returned.
Error olCalculateOptimalOccupancy_impl(ol_device_handle_t Device,
                                       ol_symbol_handle_t Kernel,
                                       size_t DynamicMemSize,
                                       size_t *GroupSize) {
  const bool IsKernel = Kernel->Kind == OL_SYMBOL_KIND_KERNEL;
  if (!IsKernel)
    return createOffloadError(ErrorCode::SYMBOL_KIND,
                              "provided symbol is not a kernel");

  auto *PluginKernel = std::get<GenericKernelTy *>(Kernel->PluginImpl);
  auto GroupSizeOrErr =
      PluginKernel->maxGroupSize(*Device->Device, DynamicMemSize);
  if (!GroupSizeOrErr)
    return GroupSizeOrErr.takeError();

  *GroupSize = *GroupSizeOrErr;
  return Error::success();
}
|
|
|
|
/// Launch \p Kernel on \p Device with the launch geometry in
/// \p LaunchSizeArgs and the \p ArgumentsSize byte argument blob at
/// \p ArgumentsData.
///
/// When \p Queue is given it must belong to \p Device (INVALID_DEVICE
/// otherwise) and its async info is used for the launch; with a null queue a
/// null async info is passed to the wrapper. \p Kernel must be a kernel
/// symbol (SYMBOL_KIND otherwise).
Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
                          ol_symbol_handle_t Kernel, const void *ArgumentsData,
                          size_t ArgumentsSize,
                          const ol_kernel_launch_size_args_t *LaunchSizeArgs) {
  auto *PluginDevice = Device->Device;

  const bool QueueOnOtherDevice = Queue && Device != Queue->Device;
  if (QueueOnOtherDevice)
    return createOffloadError(
        ErrorCode::INVALID_DEVICE,
        "device specified does not match the device of the given queue");

  if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL)
    return createOffloadError(ErrorCode::SYMBOL_KIND,
                              "provided symbol is not a kernel");

  auto *AsyncInfo = Queue ? Queue->AsyncInfo : nullptr;
  AsyncInfoWrapperTy Wrapper(*PluginDevice, AsyncInfo);

  // Translate the API launch geometry into the plugin's kernel arguments.
  const auto &Groups = LaunchSizeArgs->NumGroups;
  const auto &Threads = LaunchSizeArgs->GroupSize;
  KernelArgsTy PluginArgs{};
  PluginArgs.NumTeams[0] = Groups.x;
  PluginArgs.NumTeams[1] = Groups.y;
  PluginArgs.NumTeams[2] = Groups.z;
  PluginArgs.ThreadLimit[0] = Threads.x;
  PluginArgs.ThreadLimit[1] = Threads.y;
  PluginArgs.ThreadLimit[2] = Threads.z;
  PluginArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory;

  // The argument blob travels as a single data/size pair; ArgPtrs is pointed
  // at that pair rather than at an array of per-argument pointers.
  KernelLaunchParamsTy Blob;
  Blob.Data = const_cast<void *>(ArgumentsData);
  Blob.Size = ArgumentsSize;
  PluginArgs.ArgPtrs = reinterpret_cast<void **>(&Blob);
  // Don't do anything with pointer indirection; use arg data as-is
  PluginArgs.Flags.IsCUDA = true;

  auto *PluginKernel = std::get<GenericKernelTy *>(Kernel->PluginImpl);
  auto LaunchErr = PluginKernel->launch(*PluginDevice, PluginArgs.ArgPtrs,
                                        nullptr, PluginArgs, Wrapper);

  // The wrapper is finalized with the launch result before that result,
  // success or failure, is handed back to the caller.
  Wrapper.finalize(LaunchErr);
  if (LaunchErr)
    return LaunchErr;

  return Error::success();
}
|
|
|
|
/// Look up the symbol called \p Name of kind \p Kind in \p Program and return
/// its handle through \p Symbol.
///
/// Symbols are cached per program in KernelSymbols / GlobalSymbols, keyed by
/// name, and are built on first request. The whole lookup runs under the
/// program's SymbolListMutex. An unknown \p Kind yields INVALID_ENUMERATION.
Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
                       ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
  auto &ImageDevice = Program->Image->getDevice();

  std::lock_guard<std::mutex> Guard(Program->SymbolListMutex);

  if (Kind == OL_SYMBOL_KIND_KERNEL) {
    auto &Cached = Program->KernelSymbols[Name];
    if (!Cached) {
      // First request for this kernel: construct and initialize it.
      auto KernelOrErr = ImageDevice.constructKernel(Name);
      if (!KernelOrErr)
        return KernelOrErr.takeError();

      if (auto InitErr = KernelOrErr->init(ImageDevice, *Program->Image))
        return InitErr;

      Cached = std::make_unique<ol_symbol_impl_t>(KernelOrErr->getName(),
                                                  &*KernelOrErr);
    }

    *Symbol = Cached.get();
    return Error::success();
  }

  if (Kind == OL_SYMBOL_KIND_GLOBAL_VARIABLE) {
    auto &Cached = Program->GlobalSymbols[Name];
    if (!Cached) {
      // First request for this global: fetch its metadata from the device.
      GlobalTy Metadata{Name};
      if (auto LookupErr =
              ImageDevice.Plugin.getGlobalHandler().getGlobalMetadataFromDevice(
                  ImageDevice, *Program->Image, Metadata))
        return LookupErr;

      Cached = std::make_unique<ol_symbol_impl_t>(Metadata.getName().c_str(),
                                                  std::move(Metadata));
    }

    *Symbol = Cached.get();
    return Error::success();
  }

  return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                            "getSymbol kind enum '%i' is invalid", Kind);
}
|
|
|
|
/// Common implementation behind olGetSymbolInfo_impl and
/// olGetSymbolInfoSize_impl. \p PropSize, \p PropValue and \p PropSizeRet are
/// handed to an InfoWriter, which performs the actual write.
///
/// Supported properties:
///  - OL_SYMBOL_INFO_KIND: the symbol's kind.
///  - OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS / _SIZE: pointer and size of a
///    global variable symbol; SYMBOL_KIND is returned for any other kind.
/// Any other value yields INVALID_ENUMERATION.
Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
                                ol_symbol_info_t PropName, size_t PropSize,
                                void *PropValue, size_t *PropSizeRet) {
  InfoWriter Writer(PropSize, PropValue, PropSizeRet);

  // Succeeds when the symbol has the kind the requested property needs, and
  // otherwise builds a SYMBOL_KIND error naming both kinds.
  auto RequireKind = [&](ol_symbol_kind_t Required) {
    if (Symbol->Kind == Required)
      return Plugin::success();

    std::string Message;
    raw_string_ostream(Message)
        << PropName << ": Expected a symbol of Kind " << Required
        << " but given a symbol of Kind " << Symbol->Kind;
    return Plugin::error(ErrorCode::SYMBOL_KIND, Message.c_str());
  };

  switch (PropName) {
  case OL_SYMBOL_INFO_KIND:
    return Writer.write<ol_symbol_kind_t>(Symbol->Kind);
  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS: {
    if (auto KindErr = RequireKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
      return KindErr;
    auto &Global = std::get<GlobalTy>(Symbol->PluginImpl);
    return Writer.write<void *>(Global.getPtr());
  }
  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE: {
    if (auto KindErr = RequireKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
      return KindErr;
    auto &Global = std::get<GlobalTy>(Symbol->PluginImpl);
    return Writer.write<size_t>(Global.getSize());
  }
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "olGetSymbolInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}
|
|
|
|
/// Query symbol property \p PropName into \p PropValue (of \p PropSize
/// bytes). Forwards to olGetSymbolInfoImplDetail with no size-return pointer.
Error olGetSymbolInfo_impl(ol_symbol_handle_t Symbol, ol_symbol_info_t PropName,
                           size_t PropSize, void *PropValue) {
  size_t *const NoSizeRet = nullptr;
  return olGetSymbolInfoImplDetail(Symbol, PropName, PropSize, PropValue,
                                   NoSizeRet);
}
|
|
|
|
/// Query the size of symbol property \p PropName into \p PropSizeRet.
/// Forwards to olGetSymbolInfoImplDetail with a zero-sized, null value buffer.
Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
                               ol_symbol_info_t PropName, size_t *PropSizeRet) {
  const size_t NoValueSize = 0;
  void *const NoValue = nullptr;
  return olGetSymbolInfoImplDetail(Symbol, PropName, NoValueSize, NoValue,
                                   PropSizeRet);
}
|
|
|
|
/// Enqueue \p Callback with \p UserData on \p Queue by forwarding to
/// enqueueHostCall on the queue's device, using the queue's async info.
Error olLaunchHostFunction_impl(ol_queue_handle_t Queue,
                                ol_host_function_cb_t Callback,
                                void *UserData) {
  auto *PluginDevice = Queue->Device->Device;
  return PluginDevice->enqueueHostCall(Callback, UserData, Queue->AsyncInfo);
}
|
|
|
|
/// Register the \p Size byte range at \p Ptr with \p Device and return the
/// pointer produced by registerMemory through \p LockedPtr. Only the
/// OL_MEMORY_REGISTER_FLAG_LOCK_MEMORY bit of \p Flags is forwarded.
Error olMemRegister_impl(ol_device_handle_t Device, void *Ptr, size_t Size,
                         ol_memory_register_flags_t Flags, void **LockedPtr) {
  auto *PluginDevice = Device->Device;

  Expected<void *> RegisteredOrErr = PluginDevice->registerMemory(
      Ptr, Size, Flags & OL_MEMORY_REGISTER_FLAG_LOCK_MEMORY);
  if (auto Err = RegisteredOrErr.takeError())
    return Err;

  *LockedPtr = *RegisteredOrErr;
  return Error::success();
}
|
|
|
|
/// Unregister \p Ptr from \p Device by forwarding to unregisterMemory. Only
/// the OL_MEMORY_REGISTER_FLAG_UNLOCK_MEMORY bit of \p Flags is forwarded.
Error olMemUnregister_impl(ol_device_handle_t Device, void *Ptr,
                           ol_memory_register_flags_t Flags) {
  auto *PluginDevice = Device->Device;
  return PluginDevice->unregisterMemory(
      Ptr, Flags & OL_MEMORY_REGISTER_FLAG_UNLOCK_MEMORY);
}
|
|
|
|
/// Query, without blocking on the caller's side of this function, whether the
/// work on \p Queue has completed; the answer is produced by the device's
/// queryAsync. \p IsQueueWorkCompleted may be null when there is no
/// underlying queue, in which case nothing is written.
Error olQueryQueue_impl(ol_queue_handle_t Queue, bool *IsQueueWorkCompleted) {
  if (!Queue->AsyncInfo->Queue) {
    // No underlying queue means there's no work to complete.
    if (IsQueueWorkCompleted)
      *IsQueueWorkCompleted = true;
    return Error::success();
  }

  return Queue->Device->Device->queryAsync(Queue->AsyncInfo, false,
                                           IsQueueWorkCompleted);
}
|
|
|
|
} // namespace offload
|
|
} // namespace llvm
|