[Offload] Move (most) global state to an OffloadContext struct (#144494)

Rather than having a number of static local variables, we now use
a single `OffloadContext` struct to store global state. The struct is
initialized by `olInit`, but is never deleted (de-initialization of
Offload isn't yet implemented).

The error reporting mechanism has not been moved to the struct, since
doing so would cause issues with teardown (error messages must outlive
liboffload).
This commit is contained in:
Ross Brunton 2025-06-19 22:02:03 +01:00 committed by GitHub
parent 9fd22cb56d
commit 53336ad488
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 80 additions and 56 deletions

View File

@ -22,12 +22,12 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Error.h"
// Runtime configuration flags for liboffload.
struct OffloadConfig {
// Whether tracing output is enabled; off by default.
bool TracingEnabled = false;
// Whether validation is enabled; on by default.
bool ValidationEnabled = true;
};
OffloadConfig &offloadConfig();
namespace llvm {
namespace offload {
bool isTracingEnabled();
bool isValidationEnabled();
} // namespace offload
} // namespace llvm
// Use the StringSet container to efficiently deduplicate repeated error
// strings (e.g. if the same error is hit constantly in a long running program)

View File

@ -93,22 +93,36 @@ struct AllocInfo {
ol_alloc_type_t Type;
};
using AllocInfoMapT = DenseMap<void *, AllocInfo>;
// Returns a reference to the allocation-info map, which is held in a
// function-local static and constructed on first call.
AllocInfoMapT &allocInfoMap() {
static AllocInfoMapT AllocInfoMap{};
return AllocInfoMap;
}
// Global shared state for liboffload
struct OffloadContext;
static OffloadContext *OffloadContextVal;
struct OffloadContext {
OffloadContext(OffloadContext &) = delete;
OffloadContext(OffloadContext &&) = delete;
OffloadContext &operator=(OffloadContext &) = delete;
OffloadContext &operator=(OffloadContext &&) = delete;
using PlatformVecT = SmallVector<ol_platform_impl_t, 4>;
PlatformVecT &Platforms() {
static PlatformVecT Platforms;
return Platforms;
}
bool TracingEnabled = false;
bool ValidationEnabled = true;
DenseMap<void *, AllocInfo> AllocInfoMap{};
SmallVector<ol_platform_impl_t, 4> Platforms{};
ol_device_handle_t HostDevice() {
// The host platform is always inserted last
return &Platforms().back().Devices[0];
ol_device_handle_t HostDevice() {
// The host platform is always inserted last
return &Platforms.back().Devices[0];
}
static OffloadContext &get() {
assert(OffloadContextVal);
return *OffloadContextVal;
}
};
// Reports whether tracing is enabled.
// If the context is uninitialized (OffloadContextVal is null), tracing is
// assumed to be disabled: the null check short-circuits, so
// OffloadContext::get() is not called in that case.
bool isTracingEnabled() {
return OffloadContextVal && OffloadContext::get().TracingEnabled;
}
// Reports whether validation is enabled.
// There is no null check here: OffloadContext::get() asserts that
// OffloadContextVal is set, so the context must exist before this is called.
bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; }
template <typename HandleT> Error olDestroy(HandleT Handle) {
delete Handle;
@ -130,10 +144,12 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
#include "Shared/Targets.def"
void initPlugins() {
auto *Context = new OffloadContext{};
// Attempt to create an instance of each supported plugin.
#define PLUGIN_TARGET(Name) \
do { \
Platforms().emplace_back(ol_platform_impl_t{ \
Context->Platforms.emplace_back(ol_platform_impl_t{ \
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
{}, \
pluginNameToBackend(#Name)}); \
@ -141,7 +157,7 @@ void initPlugins() {
#include "Shared/Targets.def"
// Preemptively initialize all devices in the plugin
for (auto &Platform : Platforms()) {
for (auto &Platform : Context->Platforms) {
// Do not use the host plugin - it isn't supported.
if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
continue;
@ -157,15 +173,16 @@ void initPlugins() {
}
// Add the special host device
auto &HostPlatform = Platforms().emplace_back(
auto &HostPlatform = Context->Platforms.emplace_back(
ol_platform_impl_t{nullptr,
{ol_device_impl_t{-1, nullptr, nullptr}},
OL_PLATFORM_BACKEND_HOST});
HostDevice()->Platform = &HostPlatform;
Context->HostDevice()->Platform = &HostPlatform;
offloadConfig().TracingEnabled = std::getenv("OFFLOAD_TRACE");
offloadConfig().ValidationEnabled =
!std::getenv("OFFLOAD_DISABLE_VALIDATION");
Context->TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
OffloadContextVal = Context;
}
// TODO: We can properly reference count here and manage the resources in a more
@ -229,7 +246,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
// Find the info if it exists under any of the given names
auto GetInfo = [&](std::vector<std::string> Names) {
if (Device == HostDevice())
if (Device == OffloadContext::get().HostDevice())
return std::string("Host");
if (!Device->Device)
@ -251,8 +268,9 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
case OL_DEVICE_INFO_PLATFORM:
return ReturnValue(Device->Platform);
case OL_DEVICE_INFO_TYPE:
return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST)
: ReturnValue(OL_DEVICE_TYPE_GPU);
return Device == OffloadContext::get().HostDevice()
? ReturnValue(OL_DEVICE_TYPE_HOST)
: ReturnValue(OL_DEVICE_TYPE_GPU);
case OL_DEVICE_INFO_NAME:
return ReturnValue(GetInfo({"Device Name"}).c_str());
case OL_DEVICE_INFO_VENDOR:
@ -280,7 +298,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
}
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
for (auto &Platform : Platforms()) {
for (auto &Platform : OffloadContext::get().Platforms) {
for (auto &Device : Platform.Devices) {
if (!Callback(&Device, UserData)) {
break;
@ -311,16 +329,17 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
return Alloc.takeError();
*AllocationOut = *Alloc;
allocInfoMap().insert_or_assign(*Alloc, AllocInfo{Device, Type});
OffloadContext::get().AllocInfoMap.insert_or_assign(*Alloc,
AllocInfo{Device, Type});
return Error::success();
}
Error olMemFree_impl(void *Address) {
if (!allocInfoMap().contains(Address))
if (!OffloadContext::get().AllocInfoMap.contains(Address))
return createOffloadError(ErrorCode::INVALID_ARGUMENT,
"address is not a known allocation");
auto AllocInfo = allocInfoMap().at(Address);
auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address);
auto Device = AllocInfo.Device;
auto Type = AllocInfo.Type;
@ -328,7 +347,7 @@ Error olMemFree_impl(void *Address) {
Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
return Res;
allocInfoMap().erase(Address);
OffloadContext::get().AllocInfoMap.erase(Address);
return Error::success();
}
@ -395,7 +414,8 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, const void *SrcPtr,
ol_device_handle_t SrcDevice, size_t Size,
ol_event_handle_t *EventOut) {
if (DstDevice == HostDevice() && SrcDevice == HostDevice()) {
auto Host = OffloadContext::get().HostDevice();
if (DstDevice == Host && SrcDevice == Host) {
if (!Queue) {
std::memcpy(DstPtr, SrcPtr, Size);
return Error::success();
@ -410,11 +430,11 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
// If no queue is given the memcpy will be synchronous
auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
if (DstDevice == HostDevice()) {
if (DstDevice == Host) {
if (auto Res =
SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl))
return Res;
} else if (SrcDevice == HostDevice()) {
} else if (SrcDevice == Host) {
if (auto Res =
DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl))
return Res;

View File

@ -30,11 +30,6 @@ ol_code_location_t *&currentCodeLocation() {
return CodeLoc;
}
// Returns a reference to the OffloadConfig instance, which is held in a
// function-local static and value-initialized on first call.
OffloadConfig &offloadConfig() {
static OffloadConfig Config{};
return Config;
}
namespace llvm {
namespace offload {
// Pull in the declarations for the implementation functions. The actual entry

View File

@ -35,21 +35,30 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
}
OS << ") {\n";
OS << TAB_1 "if (offloadConfig().ValidationEnabled) {\n";
// Emit validation checks
for (const auto &Return : F.getReturns()) {
for (auto &Condition : Return.getConditions()) {
if (Condition.starts_with("`") && Condition.ends_with("`")) {
auto ConditionString = Condition.substr(1, Condition.size() - 2);
OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
OS << formatv(TAB_3 "return createOffloadError(error::ErrorCode::{0}, "
"\"validation failure: {1}\");\n",
Return.getUnprefixedValue(), ConditionString);
OS << TAB_2 "}\n\n";
bool HasValidation = llvm::any_of(F.getReturns(), [](auto &R) {
return llvm::any_of(R.getConditions(), [](auto &C) {
return C.starts_with("`") && C.ends_with("`");
});
});
if (HasValidation) {
OS << TAB_1 "if (llvm::offload::isValidationEnabled()) {\n";
// Emit validation checks
for (const auto &Return : F.getReturns()) {
for (auto &Condition : Return.getConditions()) {
if (Condition.starts_with("`") && Condition.ends_with("`")) {
auto ConditionString = Condition.substr(1, Condition.size() - 2);
OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
OS << formatv(TAB_3
"return createOffloadError(error::ErrorCode::{0}, "
"\"validation failure: {1}\");\n",
Return.getUnprefixedValue(), ConditionString);
OS << TAB_2 "}\n\n";
}
}
}
OS << TAB_1 "}\n\n";
}
OS << TAB_1 "}\n\n";
// Perform actual function call to the implementation
ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
@ -74,7 +83,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
OS << ") {\n";
// Emit pre-call prints
OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName());
OS << TAB_1 "}\n\n";
@ -85,7 +94,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
PrefixLower, F.getName(), ParamNameList);
// Emit post-call prints
OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
if (F.getParams().size() > 0) {
OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName());
for (const auto &Param : F.getParams()) {