diff --git a/offload/liboffload/API/Common.td b/offload/liboffload/API/Common.td index b47223612479..9bdd4291e096 100644 --- a/offload/liboffload/API/Common.td +++ b/offload/liboffload/API/Common.td @@ -139,14 +139,30 @@ def ol_dimensions_t : Struct { ]; } +def ol_init_args_t : Struct { + let desc = "Configuration arguments for olInit."; + let members = [ + StructMember<"size_t", "Size", "Size of this struct, used for ABI compatibility. Must be set to sizeof(ol_init_args_t) by the caller.">, + StructMember<"uint32_t", "NumPlatforms", "Number of entries in the Platforms array.">, + StructMember<"const ol_platform_backend_t*", "Platforms", "Pointer to an array of platform backends to initialize."> + ]; +} + +def OL_INIT_ARGS_INIT : Macro { + let desc = "Default initializer for ol_init_args_t. Sets Size to the correct value and all other fields to zero/NULL."; + let value = "{ sizeof(ol_init_args_t), 0, NULL }"; +} + def olInit : Function { let desc = "Perform initialization of the Offload library"; let details = [ "This must be the first API call made by a user of the Offload library", - "The underlying platforms are lazily initialized on their first use" - "Each call will increment an internal reference count that is decremented by `olShutDown`" + "Each call will increment an internal reference count that is decremented by `olShutDown`", + "If InitArgs is NULL, default configuration is used which initializes all available platforms" + ]; + let params = [ + Param<"const ol_init_args_t*", "InitArgs", "Optional pointer to initialization configuration. NULL uses defaults.", PARAM_IN_OPTIONAL> ]; - let params = []; let returns = []; } diff --git a/offload/liboffload/README.md b/offload/liboffload/README.md index 267d27172828..b65dace632f1 100644 --- a/offload/liboffload/README.md +++ b/offload/liboffload/README.md @@ -21,7 +21,7 @@ environment variable. This works with any program that uses liboffload. ```sh $ OFFLOAD_TRACE=1 ./offload.unittests ----> olInit()-> OL_SUCCESS +---> olInit(nullptr)-> OL_SUCCESS # etc ``` diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index db8a5f3a7825..939664ff5925 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -275,13 +275,22 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) { #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name(); #include "Shared/Targets.def" -Error initPlugins(OffloadContext &Context) { - // Attempt to create an instance of each supported plugin. +Error initPlugins(OffloadContext &Context, const ol_init_args_t *InitArgs) { + SmallSet Requested; + if (InitArgs && InitArgs->NumPlatforms > 0) + for (uint32_t I = 0; I < InitArgs->NumPlatforms; I++) + Requested.insert(InitArgs->Platforms[I]); + + // Attempt to create an instance of each supported plugin, skipping + // unrequested backends. The host plugin is always created. #define PLUGIN_TARGET(Name) \ do { \ - Context.Platforms.emplace_back(std::make_unique( \ - std::unique_ptr(createPlugin_##Name()), \ - pluginNameToBackend(#Name))); \ + auto Backend = pluginNameToBackend(#Name); \ + if (Requested.empty() || Backend == OL_PLATFORM_BACKEND_HOST || \ + Requested.contains(Backend)) { \ + Context.Platforms.emplace_back(std::make_unique( \ + std::unique_ptr(createPlugin_##Name()), Backend)); \ + } \ } while (false); #include "Shared/Targets.def" @@ -299,7 +308,7 @@ Error initPlugins(OffloadContext &Context) { return Plugin::success(); } -Error olInit_impl() { +Error olInit_impl(const ol_init_args_t *InitArgs) { std::lock_guard Lock(OffloadContextValMutex); if (isOffloadInitialized()) { @@ -307,10 +316,19 @@ Error olInit_impl() { return Plugin::success(); } + if (InitArgs) { + if (InitArgs->Size < sizeof(ol_init_args_t)) + return createOffloadError(ErrorCode::INVALID_SIZE, + "ol_init_args_t Size field is too small"); + if (InitArgs->NumPlatforms > 0 && !InitArgs->Platforms) + return createOffloadError(ErrorCode::INVALID_NULL_POINTER, + "NumPlatforms > 0 but Platforms is null"); + } + // Use a temporary to ensure that entry points querying OffloadContextVal do // not get a partially initialized context auto *NewContext = new OffloadContext{}; - Error InitResult = initPlugins(*NewContext); + Error InitResult = initPlugins(*NewContext, InitArgs); OffloadContextVal.store(NewContext); OffloadContext::get().RefCount++; diff --git a/offload/tools/deviceinfo/llvm-offload-device-info.cpp b/offload/tools/deviceinfo/llvm-offload-device-info.cpp index 74af3bfb1330..1fc27f90d17e 100644 --- a/offload/tools/deviceinfo/llvm-offload-device-info.cpp +++ b/offload/tools/deviceinfo/llvm-offload-device-info.cpp @@ -246,7 +246,7 @@ ol_result_t printDevice(std::ostream &S, ol_device_handle_t D) { } ol_result_t printRoot(std::ostream &S) { - OFFLOAD_ERR(olInit()); + OFFLOAD_ERR(olInit(nullptr)); S << "Liboffload Version: " << OL_VERSION_MAJOR << "." << OL_VERSION_MINOR << "." << OL_VERSION_PATCH << "\n"; diff --git a/offload/unittests/Conformance/lib/DeviceContext.cpp b/offload/unittests/Conformance/lib/DeviceContext.cpp index 6c3425f1e17c..6e6c2738db51 100644 --- a/offload/unittests/Conformance/lib/DeviceContext.cpp +++ b/offload/unittests/Conformance/lib/DeviceContext.cpp @@ -48,7 +48,7 @@ namespace { // The static 'Wrapper' instance ensures olInit() is called once at program // startup and olShutDown() is called once at program termination struct OffloadInitWrapper { - OffloadInitWrapper() { OL_CHECK(olInit()); } + OffloadInitWrapper() { OL_CHECK(olInit(nullptr)); } ~OffloadInitWrapper() { OL_CHECK(olShutDown()); } }; static OffloadInitWrapper Wrapper{}; diff --git a/offload/unittests/OffloadAPI/common/Environment.cpp b/offload/unittests/OffloadAPI/common/Environment.cpp index f8dcef6215bd..ce06e6ff38a7 100644 --- a/offload/unittests/OffloadAPI/common/Environment.cpp +++ b/offload/unittests/OffloadAPI/common/Environment.cpp @@ -19,7 +19,7 @@ using namespace llvm; // test, while having sensible lifetime for the platform environment #ifndef DISABLE_WRAPPER struct OffloadInitWrapper { - OffloadInitWrapper() { olInit(); } + OffloadInitWrapper() { olInit(nullptr); } ~OffloadInitWrapper() { olShutDown(); } }; static OffloadInitWrapper Wrapper{}; diff --git a/offload/unittests/OffloadAPI/init/olInit.cpp b/offload/unittests/OffloadAPI/init/olInit.cpp index 508615152b4f..4c74122e89d4 100644 --- a/offload/unittests/OffloadAPI/init/olInit.cpp +++ b/offload/unittests/OffloadAPI/init/olInit.cpp @@ -16,7 +16,7 @@ struct olInitTest : ::testing::Test {}; TEST_F(olInitTest, Success) { - ASSERT_SUCCESS(olInit()); + ASSERT_SUCCESS(olInit(nullptr)); ASSERT_SUCCESS(olShutDown()); } @@ -28,7 +28,22 @@ TEST_F(olInitTest, Uninitialized) { TEST_F(olInitTest, RepeatedInit) { for (size_t I = 0; I < 10; I++) { - ASSERT_SUCCESS(olInit()); + ASSERT_SUCCESS(olInit(nullptr)); ASSERT_SUCCESS(olShutDown()); } } + +TEST_F(olInitTest, WithInitArgs) { + ol_init_args_t Args = OL_INIT_ARGS_INIT; + ol_platform_backend_t Backends[] = {OL_PLATFORM_BACKEND_HOST}; + Args.NumPlatforms = 1; + Args.Platforms = Backends; + ASSERT_SUCCESS(olInit(&Args)); + ASSERT_SUCCESS(olShutDown()); +} + +TEST_F(olInitTest, InvalidSize) { + ol_init_args_t Args = OL_INIT_ARGS_INIT; + Args.Size = 0; + ASSERT_ERROR(OL_ERRC_INVALID_SIZE, olInit(&Args)); +}