[Offload] Check for initialization (#144370)

All entry points (except olInit) now check that offload has been
initialized. If not, a new `OL_ERRC_UNINITIALIZED` error is returned.
This commit is contained in:
Ross Brunton 2025-06-20 15:04:50 +01:00 committed by GitHub
parent bd36f7331a
commit e0633d59b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 44 additions and 1 deletions

View File

@ -106,6 +106,7 @@ def ErrorCode : Enum {
Etor<"ASSEMBLE_FAILURE", "assembler failure while processing binary image">,
Etor<"LINK_FAILURE", "linker failure while processing binary image">,
Etor<"BACKEND_FAILURE", "the plugin backend is in an invalid or unsupported state">,
Etor<"UNINITIALIZED", "not initialized">,
// Handle related errors - only makes sense for liboffload
Etor<"INVALID_NULL_HANDLE", "a handle argument is null when it should not be">,

View File

@ -26,6 +26,7 @@ namespace llvm {
namespace offload {
bool isTracingEnabled();
bool isValidationEnabled();
bool isOffloadInitialized();
} // namespace offload
} // namespace llvm

View File

@ -120,9 +120,10 @@ struct OffloadContext {
// If the context is uninited, then we assume tracing is disabled
bool isTracingEnabled() {
return OffloadContextVal && OffloadContext::get().TracingEnabled;
return isOffloadInitialized() && OffloadContext::get().TracingEnabled;
}
bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; }
bool isOffloadInitialized() { return OffloadContextVal != nullptr; }
template <typename HandleT> Error olDestroy(HandleT Handle) {
delete Handle;

View File

@ -82,6 +82,10 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
}
OS << ") {\n";
// Check offload is initialized
if (F.getName() != "olInit")
OS << "if (!llvm::offload::isOffloadInitialized()) return &UninitError;";
// Emit pre-call prints
OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName());
@ -143,6 +147,14 @@ static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) {
void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) {
OS << GenericHeader;
constexpr const char *UninitMessage =
"liboffload has not been initialized - please call olInit before using "
"this API";
OS << formatv("static {0}_error_struct_t UninitError = "
"{{{1}_ERRC_UNINITIALIZED, \"{2}\"};",
PrefixLower, PrefixUpper, UninitMessage);
for (auto *R : Records.getAllDerivedDefinitions("Function")) {
EmitValidationFunc(FunctionRec{R}, OS);
EmitEntryPointFunc(FunctionRec{R}, OS);

View File

@ -12,6 +12,10 @@ add_offload_unittest("event"
event/olDestroyEvent.cpp
event/olWaitEvent.cpp)
add_offload_unittest("init"
init/olInit.cpp)
target_compile_definitions("init.unittests" PRIVATE DISABLE_WRAPPER)
add_offload_unittest("kernel"
kernel/olGetKernel.cpp
kernel/olLaunchKernel.cpp)

View File

@ -17,11 +17,13 @@ using namespace llvm;
// Wrapper so we don't have to constantly init and shutdown Offload in every
// test, while having sensible lifetime for the platform environment
#ifndef DISABLE_WRAPPER
struct OffloadInitWrapper {
OffloadInitWrapper() { olInit(); }
~OffloadInitWrapper() { olShutDown(); }
};
static OffloadInitWrapper Wrapper{};
#endif
static cl::opt<std::string>
SelectedPlatform("platform", cl::desc("Only test the specified platform"),

View File

@ -0,0 +1,22 @@
//===------- Offload API tests - olInit -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// NOTE: For this test suite, the implicit olInit/olShutDown doesn't happen, so
// tests have to do it themselves
#include "../common/Fixtures.hpp"
#include <OffloadAPI.h>
#include <gtest/gtest.h>
struct olInitTest : ::testing::Test {};
TEST_F(olInitTest, Uninitialized) {
ASSERT_ERROR(OL_ERRC_UNINITIALIZED,
olIterateDevices(
[](ol_device_handle_t, void *) { return false; }, nullptr));
}