diff --git a/offload/liboffload/API/APIDefs.td b/offload/liboffload/API/APIDefs.td index cee4adea1d9f..640932dcf846 100644 --- a/offload/liboffload/API/APIDefs.td +++ b/offload/liboffload/API/APIDefs.td @@ -199,7 +199,7 @@ class Typedef : APIObject { string value; } class FptrTypedef : APIObject { list<Param> params; - list<Return> returns; + string return; } class Macro : APIObject { diff --git a/offload/liboffload/API/Common.td b/offload/liboffload/API/Common.td index 5b19d1d47129..de7502b54061 100644 --- a/offload/liboffload/API/Common.td +++ b/offload/liboffload/API/Common.td @@ -62,6 +62,27 @@ def : Handle { let desc = "Handle of context object"; } +def : Handle { + let name = "ol_queue_handle_t"; + let desc = "Handle of queue object"; +} + +def : Handle { + let name = "ol_event_handle_t"; + let desc = "Handle of event object"; +} + +def : Handle { + let name = "ol_program_handle_t"; + let desc = "Handle of program object"; +} + +def : Typedef { + let name = "ol_kernel_handle_t"; + let desc = "Handle of kernel object"; + let value = "void *"; +} + def : Enum { let name = "ol_errc_t"; let desc = "Defines Return/Error codes"; @@ -69,12 +90,11 @@ def : Enum { Etor<"SUCCESS", "Success">, Etor<"INVALID_VALUE", "Invalid Value">, Etor<"INVALID_PLATFORM", "Invalid platform">, - Etor<"DEVICE_NOT_FOUND", "Device not found">, Etor<"INVALID_DEVICE", "Invalid device">, - Etor<"DEVICE_LOST", "Device hung, reset, was removed, or driver update occurred">, - Etor<"UNINITIALIZED", "plugin is not initialized or specific entry-point is not implemented">, + Etor<"INVALID_QUEUE", "Invalid queue">, + Etor<"INVALID_EVENT", "Invalid event">, + Etor<"INVALID_KERNEL_NAME", "Named kernel not found in the program binary">, Etor<"OUT_OF_RESOURCES", "Out of resources">, - Etor<"UNSUPPORTED_VERSION", "generic error code for unsupported versions">, Etor<"UNSUPPORTED_FEATURE", "generic error code for unsupported features">, Etor<"INVALID_ARGUMENT", "generic error code for invalid arguments">, 
Etor<"INVALID_NULL_HANDLE", "handle argument is not valid">, diff --git a/offload/liboffload/API/Device.td b/offload/liboffload/API/Device.td index 30c0b71fe7b3..28c96bb5d291 100644 --- a/offload/liboffload/API/Device.td +++ b/offload/liboffload/API/Device.td @@ -12,7 +12,7 @@ def : Enum { let name = "ol_device_type_t"; - let desc = "Supported device types"; + let desc = "Supported device types."; let etors =[ Etor<"DEFAULT", "The default device type as preferred by the runtime">, Etor<"ALL", "Devices of all types">, @@ -23,7 +23,7 @@ def : Enum { def : Enum { let name = "ol_device_info_t"; - let desc = "Supported device info"; + let desc = "Supported device info."; let is_typed = 1; let etors =[ TaggedEtor<"TYPE", "ol_device_type_t", "type of the device">, @@ -34,39 +34,34 @@ def : Enum { ]; } -def : Function { - let name = "olGetDeviceCount"; - let desc = "Retrieves the number of available devices within a platform"; +def : FptrTypedef { + let name = "ol_device_iterate_cb_t"; + let desc = "User-provided function to be used with `olIterateDevices`"; let params = [ - Param<"ol_platform_handle_t", "Platform", "handle of the platform instance", PARAM_IN>, - Param<"uint32_t*", "NumDevices", "pointer to the number of devices.", PARAM_OUT> + Param<"ol_device_handle_t", "Device", "the device handle of the current iteration", PARAM_IN>, + Param<"void*", "UserData", "optional user data", PARAM_IN_OPTIONAL> ]; - let returns = []; + let return = "bool"; } def : Function { - let name = "olGetDevice"; - let desc = "Retrieves devices within a platform"; + let name = "olIterateDevices"; + let desc = "Iterates over all available devices, calling the callback for each device."; let details = [ - "Multiple calls to this function will return identical device handles, in the same order.", + "If the user-provided callback returns `false`, the iteration is stopped." 
]; let params = [ - Param<"ol_platform_handle_t", "Platform", "handle of the platform instance", PARAM_IN>, - Param<"uint32_t", "NumEntries", "the number of devices to be added to phDevices, which must be greater than zero", PARAM_IN>, - RangedParam<"ol_device_handle_t*", "Devices", "Array of device handles. " - "If NumEntries is less than the number of devices available, then this function shall only retrieve that number of devices.", PARAM_OUT, - Range<"0", "NumEntries">> + Param<"ol_device_iterate_cb_t", "Callback", "User-provided function called for each available device", PARAM_IN>, + Param<"void*", "UserData", "Optional user data to pass to the callback", PARAM_IN_OPTIONAL> ]; let returns = [ - Return<"OL_ERRC_INVALID_SIZE", [ - "`NumEntries == 0`" - ]> + Return<"OL_ERRC_INVALID_DEVICE"> ]; } def : Function { let name = "olGetDeviceInfo"; - let desc = "Queries the given property of the device"; + let desc = "Queries the given property of the device."; let details = []; let params = [ Param<"ol_device_handle_t", "Device", "handle of the device instance", PARAM_IN>, @@ -90,7 +85,7 @@ def : Function { def : Function { let name = "olGetDeviceInfoSize"; - let desc = "Returns the storage size of the given device query"; + let desc = "Returns the storage size of the given device query."; let details = []; let params = [ Param<"ol_device_handle_t", "Device", "handle of the device instance", PARAM_IN>, diff --git a/offload/liboffload/API/Event.td b/offload/liboffload/API/Event.td new file mode 100644 index 000000000000..c9f79159cf26 --- /dev/null +++ b/offload/liboffload/API/Event.td @@ -0,0 +1,31 @@ +//===-- Event.td - Event definitions for Offload -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains Offload API definitions related to the event handle +// +//===----------------------------------------------------------------------===// + +def : Function { + let name = "olDestroyEvent"; + let desc = "Destroy the event and free all underlying resources."; + let details = []; + let params = [ + Param<"ol_event_handle_t", "Event", "handle of the event", PARAM_IN> + ]; + let returns = []; +} + +def : Function { + let name = "olWaitEvent"; + let desc = "Wait for the event to be complete."; + let details = []; + let params = [ + Param<"ol_event_handle_t", "Event", "handle of the event", PARAM_IN> + ]; + let returns = []; +} diff --git a/offload/liboffload/API/Kernel.td b/offload/liboffload/API/Kernel.td new file mode 100644 index 000000000000..247f9c1bf5b6 --- /dev/null +++ b/offload/liboffload/API/Kernel.td @@ -0,0 +1,61 @@ +//===-- Kernel.td - Kernel definitions for Offload ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains Offload API definitions related to the kernel handle +// +//===----------------------------------------------------------------------===// + +def : Function { + let name = "olGetKernel"; + let desc = "Get a kernel from the function identified by `KernelName` in the given program."; + let details = [ + "The kernel handle returned is owned by the device so does not need to be destroyed." 
+ ]; + let params = [ + Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>, + Param<"const char*", "KernelName", "name of the kernel entry point in the program", PARAM_IN>, + Param<"ol_kernel_handle_t*", "Kernel", "output pointer for the fetched kernel", PARAM_OUT> + ]; + let returns = []; +} + +def : Struct { + let name = "ol_kernel_launch_size_args_t"; + let desc = "Size-related arguments for a kernel launch."; + let members = [ + StructMember<"size_t", "Dimensions", "Number of work dimensions">, + StructMember<"size_t", "NumGroupsX", "Number of work groups on the X dimension">, + StructMember<"size_t", "NumGroupsY", "Number of work groups on the Y dimension">, + StructMember<"size_t", "NumGroupsZ", "Number of work groups on the Z dimension">, + StructMember<"size_t", "GroupSizeX", "Size of a work group on the X dimension.">, + StructMember<"size_t", "GroupSizeY", "Size of a work group on the Y dimension.">, + StructMember<"size_t", "GroupSizeZ", "Size of a work group on the Z dimension.">, + StructMember<"size_t", "DynSharedMemory", "Size of dynamic shared memory in bytes."> + ]; +} + +def : Function { + let name = "olLaunchKernel"; + let desc = "Enqueue a kernel launch with the specified size and parameters."; + let details = [ + "If a queue is not specified, kernel execution happens synchronously" + ]; + let params = [ + Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN_OPTIONAL>, + Param<"ol_device_handle_t", "Device", "handle of the device to execute on", PARAM_IN>, + Param<"ol_kernel_handle_t", "Kernel", "handle of the kernel", PARAM_IN>, + Param<"const void*", "ArgumentsData", "pointer to the kernel argument struct", PARAM_IN>, + Param<"size_t", "ArgumentsSize", "size of the kernel argument struct", PARAM_IN>, + Param<"const ol_kernel_launch_size_args_t*", "LaunchSizeArgs", "pointer to the struct containing launch size parameters", PARAM_IN>, + Param<"ol_event_handle_t*", "EventOut", "optional recorded event for 
the enqueued operation", PARAM_OUT_OPTIONAL> + ]; + let returns = [ + Return<"OL_ERRC_INVALID_ARGUMENT", ["`Queue == NULL && EventOut != NULL`"]>, + Return<"OL_ERRC_INVALID_DEVICE", ["If Queue is non-null but does not belong to Device"]>, + ]; +} diff --git a/offload/liboffload/API/Memory.td b/offload/liboffload/API/Memory.td new file mode 100644 index 000000000000..9cd1ef6362e1 --- /dev/null +++ b/offload/liboffload/API/Memory.td @@ -0,0 +1,68 @@ +//===-- Memory.td - Memory definitions for Offload ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains Offload API definitions related to memory allocations +// +//===----------------------------------------------------------------------===// + +def : Enum { + let name = "ol_alloc_type_t"; + let desc = "Represents the type of allocation made with olMemAlloc."; + let etors = [ + Etor<"HOST", "Host allocation">, + Etor<"DEVICE", "Device allocation">, + Etor<"MANAGED", "Managed allocation"> + ]; +} + +def : Function { + let name = "olMemAlloc"; + let desc = "Creates a memory allocation on the specified device."; + let params = [ + Param<"ol_device_handle_t", "Device", "handle of the device to allocate on", PARAM_IN>, + Param<"ol_alloc_type_t", "Type", "type of the allocation", PARAM_IN>, + Param<"size_t", "Size", "size of the allocation in bytes", PARAM_IN>, + Param<"void**", "AllocationOut", "output for the allocated pointer", PARAM_OUT> + ]; + let returns = [ + Return<"OL_ERRC_INVALID_SIZE", [ + "`Size == 0`" + ]> + ]; +} + +def : Function { + let name = "olMemFree"; + let desc = "Frees a memory allocation previously made by olMemAlloc."; + let params = [ + Param<"void*", "Address", "address of the allocation to free", 
PARAM_IN>, + ]; + let returns = []; +} + +def : Function { + let name = "olMemcpy"; + let desc = "Enqueue a memcpy operation."; + let details = [ + "For host pointers, use the host device belonging to the OL_PLATFORM_BACKEND_HOST platform.", + "If a queue is specified, at least one device must be a non-host device", + "If a queue is not specified, the memcpy happens synchronously" + ]; + let params = [ + Param<"ol_queue_handle_t", "Queue", "handle of the queue.", PARAM_IN_OPTIONAL>, + Param<"void*", "DstPtr", "pointer to copy to", PARAM_IN>, + Param<"ol_device_handle_t", "DstDevice", "device that DstPtr belongs to", PARAM_IN>, + Param<"void*", "SrcPtr", "pointer to copy from", PARAM_IN>, + Param<"ol_device_handle_t", "SrcDevice", "device that SrcPtr belongs to", PARAM_IN>, + Param<"size_t", "Size", "size in bytes of data to copy", PARAM_IN>, + Param<"ol_event_handle_t*", "EventOut", "optional recorded event for the enqueued operation", PARAM_OUT_OPTIONAL> + ]; + let returns = [ + Return<"OL_ERRC_INVALID_ARGUMENT", ["`Queue == NULL && EventOut != NULL`"]> + ]; +} diff --git a/offload/liboffload/API/OffloadAPI.td b/offload/liboffload/API/OffloadAPI.td index 8a0c3c405812..f9829155b6ce 100644 --- a/offload/liboffload/API/OffloadAPI.td +++ b/offload/liboffload/API/OffloadAPI.td @@ -13,3 +13,8 @@ include "APIDefs.td" include "Common.td" include "Platform.td" include "Device.td" +include "Memory.td" +include "Queue.td" +include "Event.td" +include "Program.td" +include "Kernel.td" diff --git a/offload/liboffload/API/Platform.td b/offload/liboffload/API/Platform.td index 03e70cf96ac9..97c2cc2d0570 100644 --- a/offload/liboffload/API/Platform.td +++ b/offload/liboffload/API/Platform.td @@ -9,44 +9,10 @@ // This file contains Offload API definitions related to the Platform handle // //===----------------------------------------------------------------------===// -def : Function { - let name = "olGetPlatform"; - let desc = "Retrieves all available platforms"; - let details = 
[ - "Multiple calls to this function will return identical platforms handles, in the same order.", - ]; - let params = [ - Param<"uint32_t", "NumEntries", - "The number of platforms to be added to Platforms. NumEntries must be " - "greater than zero.", - PARAM_IN>, - RangedParam<"ol_platform_handle_t*", "Platforms", - "Array of handle of platforms. If NumEntries is less than the number of " - "platforms available, then olGetPlatform shall only retrieve that " - "number of platforms.", - PARAM_OUT, Range<"0", "NumEntries">> - ]; - let returns = [ - Return<"OL_ERRC_INVALID_SIZE", [ - "`NumEntries == 0`" - ]> - ]; -} - -def : Function { - let name = "olGetPlatformCount"; - let desc = "Retrieves the number of available platforms"; - let params = [ - Param<"uint32_t*", - "NumPlatforms", "returns the total number of platforms available.", - PARAM_OUT> - ]; - let returns = []; -} def : Enum { let name = "ol_platform_info_t"; - let desc = "Supported platform info"; + let desc = "Supported platform info."; let is_typed = 1; let etors = [ TaggedEtor<"NAME", "char[]", "The string denoting name of the platform. The size of the info needs to be dynamically queried.">, @@ -58,17 +24,18 @@ def : Enum { def : Enum { let name = "ol_platform_backend_t"; - let desc = "Identifies the native backend of the platform"; + let desc = "Identifies the native backend of the platform."; let etors =[ Etor<"UNKNOWN", "The backend is not recognized">, Etor<"CUDA", "The backend is CUDA">, Etor<"AMDGPU", "The backend is AMDGPU">, + Etor<"HOST", "The backend is the host">, ]; } def : Function { let name = "olGetPlatformInfo"; - let desc = "Queries the given property of the platform"; + let desc = "Queries the given property of the platform."; let details = [ "`olGetPlatformInfoSize` can be used to query the storage size " "required for the given query." 
@@ -96,7 +63,7 @@ def : Function { def : Function { let name = "olGetPlatformInfoSize"; - let desc = "Returns the storage size of the given platform query"; + let desc = "Returns the storage size of the given platform query."; let details = []; let params = [ Param<"ol_platform_handle_t", "Platform", "handle of the platform", PARAM_IN>, diff --git a/offload/liboffload/API/Program.td b/offload/liboffload/API/Program.td new file mode 100644 index 000000000000..8c88fe6e21e6 --- /dev/null +++ b/offload/liboffload/API/Program.td @@ -0,0 +1,34 @@ +//===-- Program.td - Program definitions for Offload -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains Offload API definitions related to the program handle +// +//===----------------------------------------------------------------------===// + +def : Function { + let name = "olCreateProgram"; + let desc = "Create a program for the device from the binary image pointed to by `ProgData`."; + let details = []; + let params = [ + Param<"ol_device_handle_t", "Device", "handle of the device", PARAM_IN>, + Param<"const void*", "ProgData", "pointer to the program binary data", PARAM_IN>, + Param<"size_t", "ProgDataSize", "size of the program binary in bytes", PARAM_IN>, + Param<"ol_program_handle_t*", "Program", "output pointer for the created program", PARAM_OUT> + ]; + let returns = []; +} + +def : Function { + let name = "olDestroyProgram"; + let desc = "Destroy the program and free all underlying resources."; + let details = []; + let params = [ + Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN> + ]; + let returns = []; +} diff --git a/offload/liboffload/API/Queue.td b/offload/liboffload/API/Queue.td new file 
mode 100644 index 000000000000..b5bb619c5751 --- /dev/null +++ b/offload/liboffload/API/Queue.td @@ -0,0 +1,42 @@ +//===-- Queue.td - Queue definitions for Offload -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains Offload API definitions related to the queue handle +// +//===----------------------------------------------------------------------===// + +def : Function { + let name = "olCreateQueue"; + let desc = "Create a queue for the given device."; + let details = []; + let params = [ + Param<"ol_device_handle_t", "Device", "handle of the device", PARAM_IN>, + Param<"ol_queue_handle_t*", "Queue", "output pointer for the created queue", PARAM_OUT> + ]; + let returns = []; +} + +def : Function { + let name = "olDestroyQueue"; + let desc = "Destroy the queue and free all underlying resources."; + let details = []; + let params = [ + Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN> + ]; + let returns = []; +} + +def : Function { + let name = "olWaitQueue"; + let desc = "Wait for the enqueued work on a queue to complete."; + let details = []; + let params = [ + Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN> + ]; + let returns = []; +} diff --git a/offload/liboffload/API/README.md b/offload/liboffload/API/README.md index b59ac2782a2b..fda1ad39fa93 100644 --- a/offload/liboffload/API/README.md +++ b/offload/liboffload/API/README.md @@ -138,8 +138,8 @@ allow more backends to be easily added in future. A new object can be added to the API by adding to one of the existing `.td` files. It is also possible to add a new tablegen file to the API by adding it -to the includes in `OffloadAPI.td`. 
When the offload target is rebuilt, the -new definition will be included in the generated files. +to the includes in `OffloadAPI.td`. When the `OffloadGenerate` target is +rebuilt, the new definition will be included in the generated files. ### Adding a new entry point @@ -147,4 +147,4 @@ When a new entry point is added (e.g. `offloadDeviceFoo`), the actual entry point is automatically generated, which contains validation and tracing code. It expects an implementation function (`offloadDeviceFoo_impl`) to be defined, which it will call into. The definition of this implementation function should -be added to `src/offload_impl.cpp` +be added to `src/OffloadImpl.cpp` diff --git a/offload/liboffload/include/OffloadImpl.hpp b/offload/liboffload/include/OffloadImpl.hpp index 6d745095f310..ec470a355309 100644 --- a/offload/liboffload/include/OffloadImpl.hpp +++ b/offload/liboffload/include/OffloadImpl.hpp @@ -22,6 +22,7 @@ struct OffloadConfig { bool TracingEnabled = false; + bool ValidationEnabled = true; }; OffloadConfig &offloadConfig(); diff --git a/offload/liboffload/include/generated/OffloadAPI.h b/offload/liboffload/include/generated/OffloadAPI.h index 11fcc96625ab..ace31c57cf2f 100644 --- a/offload/liboffload/include/generated/OffloadAPI.h +++ b/offload/liboffload/include/generated/OffloadAPI.h @@ -75,15 +75,31 @@ extern "C" { /////////////////////////////////////////////////////////////////////////////// /// @brief Handle of a platform instance -typedef struct ol_platform_handle_t_ *ol_platform_handle_t; +typedef struct ol_platform_impl_t *ol_platform_handle_t; /////////////////////////////////////////////////////////////////////////////// /// @brief Handle of platform's device object -typedef struct ol_device_handle_t_ *ol_device_handle_t; +typedef struct ol_device_impl_t *ol_device_handle_t; /////////////////////////////////////////////////////////////////////////////// /// @brief Handle of context object -typedef struct ol_context_handle_t_ 
*ol_context_handle_t; +typedef struct ol_context_impl_t *ol_context_handle_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Handle of queue object +typedef struct ol_queue_impl_t *ol_queue_handle_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Handle of event object +typedef struct ol_event_impl_t *ol_event_handle_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Handle of program object +typedef struct ol_program_impl_t *ol_program_handle_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Handle of kernel object +typedef void *ol_kernel_handle_t; /////////////////////////////////////////////////////////////////////////////// /// @brief Defines Return/Error codes @@ -94,34 +110,32 @@ typedef enum ol_errc_t { OL_ERRC_INVALID_VALUE = 1, /// Invalid platform OL_ERRC_INVALID_PLATFORM = 2, - /// Device not found - OL_ERRC_DEVICE_NOT_FOUND = 3, /// Invalid device - OL_ERRC_INVALID_DEVICE = 4, - /// Device hung, reset, was removed, or driver update occurred - OL_ERRC_DEVICE_LOST = 5, - /// plugin is not initialized or specific entry-point is not implemented - OL_ERRC_UNINITIALIZED = 6, + OL_ERRC_INVALID_DEVICE = 3, + /// Invalid queue + OL_ERRC_INVALID_QUEUE = 4, + /// Invalid event + OL_ERRC_INVALID_EVENT = 5, + /// Named kernel not found in the program binary + OL_ERRC_INVALID_KERNEL_NAME = 6, /// Out of resources OL_ERRC_OUT_OF_RESOURCES = 7, - /// generic error code for unsupported versions - OL_ERRC_UNSUPPORTED_VERSION = 8, /// generic error code for unsupported features - OL_ERRC_UNSUPPORTED_FEATURE = 9, + OL_ERRC_UNSUPPORTED_FEATURE = 8, /// generic error code for invalid arguments - OL_ERRC_INVALID_ARGUMENT = 10, + OL_ERRC_INVALID_ARGUMENT = 9, /// handle argument is not valid - OL_ERRC_INVALID_NULL_HANDLE = 11, + OL_ERRC_INVALID_NULL_HANDLE = 10, /// pointer argument may 
not be nullptr - OL_ERRC_INVALID_NULL_POINTER = 12, + OL_ERRC_INVALID_NULL_POINTER = 11, /// invalid size or dimensions (e.g., must not be zero, or is out of bounds) - OL_ERRC_INVALID_SIZE = 13, + OL_ERRC_INVALID_SIZE = 12, /// enumerator argument is not valid - OL_ERRC_INVALID_ENUMERATION = 14, + OL_ERRC_INVALID_ENUMERATION = 13, /// enumerator argument is not supported by the device - OL_ERRC_UNSUPPORTED_ENUMERATION = 15, + OL_ERRC_UNSUPPORTED_ENUMERATION = 14, /// Unknown or internal error - OL_ERRC_UNKNOWN = 16, + OL_ERRC_UNKNOWN = 15, /// @cond OL_ERRC_FORCE_UINT32 = 0x7fffffff /// @endcond @@ -188,48 +202,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olInit(); OL_APIEXPORT ol_result_t OL_APICALL olShutDown(); /////////////////////////////////////////////////////////////////////////////// -/// @brief Retrieves all available platforms -/// -/// @details -/// - Multiple calls to this function will return identical platforms -/// handles, in the same order. -/// -/// @returns -/// - ::OL_RESULT_SUCCESS -/// - ::OL_ERRC_UNINITIALIZED -/// - ::OL_ERRC_DEVICE_LOST -/// - ::OL_ERRC_INVALID_SIZE -/// + `NumEntries == 0` -/// - ::OL_ERRC_INVALID_NULL_HANDLE -/// - ::OL_ERRC_INVALID_NULL_POINTER -/// + `NULL == Platforms` -OL_APIEXPORT ol_result_t OL_APICALL olGetPlatform( - // [in] The number of platforms to be added to Platforms. NumEntries must be - // greater than zero. - uint32_t NumEntries, - // [out] Array of handle of platforms. If NumEntries is less than the number - // of platforms available, then olGetPlatform shall only retrieve that - // number of platforms. 
- ol_platform_handle_t *Platforms); - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Retrieves the number of available platforms -/// -/// @details -/// -/// @returns -/// - ::OL_RESULT_SUCCESS -/// - ::OL_ERRC_UNINITIALIZED -/// - ::OL_ERRC_DEVICE_LOST -/// - ::OL_ERRC_INVALID_NULL_HANDLE -/// - ::OL_ERRC_INVALID_NULL_POINTER -/// + `NULL == NumPlatforms` -OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformCount( - // [out] returns the total number of platforms available. - uint32_t *NumPlatforms); - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Supported platform info +/// @brief Supported platform info. typedef enum ol_platform_info_t { /// [char[]] The string denoting name of the platform. The size of the info /// needs to be dynamically queried. @@ -249,7 +222,7 @@ typedef enum ol_platform_info_t { } ol_platform_info_t; /////////////////////////////////////////////////////////////////////////////// -/// @brief Identifies the native backend of the platform +/// @brief Identifies the native backend of the platform. typedef enum ol_platform_backend_t { /// The backend is not recognized OL_PLATFORM_BACKEND_UNKNOWN = 0, @@ -257,6 +230,8 @@ typedef enum ol_platform_backend_t { OL_PLATFORM_BACKEND_CUDA = 1, /// The backend is AMDGPU OL_PLATFORM_BACKEND_AMDGPU = 2, + /// The backend is the host + OL_PLATFORM_BACKEND_HOST = 3, /// @cond OL_PLATFORM_BACKEND_FORCE_UINT32 = 0x7fffffff /// @endcond @@ -264,7 +239,7 @@ typedef enum ol_platform_backend_t { } ol_platform_backend_t; /////////////////////////////////////////////////////////////////////////////// -/// @brief Queries the given property of the platform +/// @brief Queries the given property of the platform. 
/// /// @details /// - `olGetPlatformInfoSize` can be used to query the storage size required @@ -298,7 +273,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformInfo( void *PropValue); /////////////////////////////////////////////////////////////////////////////// -/// @brief Returns the storage size of the given platform query +/// @brief Returns the storage size of the given platform query. /// /// @details /// @@ -322,7 +297,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformInfoSize( size_t *PropSizeRet); /////////////////////////////////////////////////////////////////////////////// -/// @brief Supported device types +/// @brief Supported device types. typedef enum ol_device_type_t { /// The default device type as preferred by the runtime OL_DEVICE_TYPE_DEFAULT = 0, @@ -339,7 +314,7 @@ typedef enum ol_device_type_t { } ol_device_type_t; /////////////////////////////////////////////////////////////////////////////// -/// @brief Supported device info +/// @brief Supported device info. typedef enum ol_device_info_t { /// [ol_device_type_t] type of the device OL_DEVICE_INFO_TYPE = 0, @@ -358,54 +333,36 @@ typedef enum ol_device_info_t { } ol_device_info_t; /////////////////////////////////////////////////////////////////////////////// -/// @brief Retrieves the number of available devices within a platform +/// @brief User-provided function to be used with `olIterateDevices` +typedef bool (*ol_device_iterate_cb_t)( + // the device handle of the current iteration + ol_device_handle_t Device, + // optional user data + void *UserData); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Iterates over all available devices, calling the callback for each +/// device. /// /// @details +/// - If the user-provided callback returns `false`, the iteration is +/// stopped. 
/// /// @returns /// - ::OL_RESULT_SUCCESS /// - ::OL_ERRC_UNINITIALIZED /// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_DEVICE /// - ::OL_ERRC_INVALID_NULL_HANDLE -/// + `NULL == Platform` /// - ::OL_ERRC_INVALID_NULL_POINTER -/// + `NULL == NumDevices` -OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceCount( - // [in] handle of the platform instance - ol_platform_handle_t Platform, - // [out] pointer to the number of devices. - uint32_t *NumDevices); +OL_APIEXPORT ol_result_t OL_APICALL olIterateDevices( + // [in] User-provided function called for each available device + ol_device_iterate_cb_t Callback, + // [in][optional] Optional user data to pass to the callback + void *UserData); /////////////////////////////////////////////////////////////////////////////// -/// @brief Retrieves devices within a platform -/// -/// @details -/// - Multiple calls to this function will return identical device handles, -/// in the same order. -/// -/// @returns -/// - ::OL_RESULT_SUCCESS -/// - ::OL_ERRC_UNINITIALIZED -/// - ::OL_ERRC_DEVICE_LOST -/// - ::OL_ERRC_INVALID_SIZE -/// + `NumEntries == 0` -/// - ::OL_ERRC_INVALID_NULL_HANDLE -/// + `NULL == Platform` -/// - ::OL_ERRC_INVALID_NULL_POINTER -/// + `NULL == Devices` -OL_APIEXPORT ol_result_t OL_APICALL olGetDevice( - // [in] handle of the platform instance - ol_platform_handle_t Platform, - // [in] the number of devices to be added to phDevices, which must be - // greater than zero - uint32_t NumEntries, - // [out] Array of device handles. If NumEntries is less than the number of - // devices available, then this function shall only retrieve that number of - // devices. - ol_device_handle_t *Devices); - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Queries the given property of the device +/// @brief Queries the given property of the device. 
/// /// @details /// @@ -437,7 +394,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfo( void *PropValue); /////////////////////////////////////////////////////////////////////////////// -/// @brief Returns the storage size of the given device query +/// @brief Returns the storage size of the given device query. /// /// @details /// @@ -461,19 +418,294 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfoSize( size_t *PropSizeRet); /////////////////////////////////////////////////////////////////////////////// -/// @brief Function parameters for olGetPlatform -/// @details Each entry is a pointer to the parameter passed to the function; -typedef struct ol_get_platform_params_t { - uint32_t *pNumEntries; - ol_platform_handle_t **pPlatforms; -} ol_get_platform_params_t; +/// @brief Represents the type of allocation made with olMemAlloc. +typedef enum ol_alloc_type_t { + /// Host allocation + OL_ALLOC_TYPE_HOST = 0, + /// Device allocation + OL_ALLOC_TYPE_DEVICE = 1, + /// Managed allocation + OL_ALLOC_TYPE_MANAGED = 2, + /// @cond + OL_ALLOC_TYPE_FORCE_UINT32 = 0x7fffffff + /// @endcond + +} ol_alloc_type_t; /////////////////////////////////////////////////////////////////////////////// -/// @brief Function parameters for olGetPlatformCount -/// @details Each entry is a pointer to the parameter passed to the function; -typedef struct ol_get_platform_count_params_t { - uint32_t **pNumPlatforms; -} ol_get_platform_count_params_t; +/// @brief Creates a memory allocation on the specified device. 
+/// +/// @details +/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_SIZE +/// + `Size == 0` +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// + `NULL == Device` +/// - ::OL_ERRC_INVALID_NULL_POINTER +/// + `NULL == AllocationOut` +OL_APIEXPORT ol_result_t OL_APICALL olMemAlloc( + // [in] handle of the device to allocate on + ol_device_handle_t Device, + // [in] type of the allocation + ol_alloc_type_t Type, + // [in] size of the allocation in bytes + size_t Size, + // [out] output for the allocated pointer + void **AllocationOut); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Frees a memory allocation previously made by olMemAlloc. +/// +/// @details +/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// - ::OL_ERRC_INVALID_NULL_POINTER +/// + `NULL == Address` +OL_APIEXPORT ol_result_t OL_APICALL olMemFree( + // [in] address of the allocation to free + void *Address); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Enqueue a memcpy operation. +/// +/// @details +/// - For host pointers, use the device returned by olGetHostDevice +/// - If a queue is specified, at least one device must be a non-host device +/// - If a queue is not specified, the memcpy happens synchronously +/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_ARGUMENT +/// + `Queue == NULL && EventOut != NULL` +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// + `NULL == DstDevice` +/// + `NULL == SrcDevice` +/// - ::OL_ERRC_INVALID_NULL_POINTER +/// + `NULL == DstPtr` +/// + `NULL == SrcPtr` +OL_APIEXPORT ol_result_t OL_APICALL olMemcpy( + // [in][optional] handle of the queue. 
+ ol_queue_handle_t Queue, + // [in] pointer to copy to + void *DstPtr, + // [in] device that DstPtr belongs to + ol_device_handle_t DstDevice, + // [in] pointer to copy from + void *SrcPtr, + // [in] device that SrcPtr belongs to + ol_device_handle_t SrcDevice, + // [in] size in bytes of data to copy + size_t Size, + // [out][optional] optional recorded event for the enqueued operation + ol_event_handle_t *EventOut); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Create a queue for the given device. +/// +/// @details +/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// + `NULL == Device` +/// - ::OL_ERRC_INVALID_NULL_POINTER +/// + `NULL == Queue` +OL_APIEXPORT ol_result_t OL_APICALL olCreateQueue( + // [in] handle of the device + ol_device_handle_t Device, + // [out] output pointer for the created queue + ol_queue_handle_t *Queue); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Destroy the queue and free all underlying resources. +/// +/// @details +/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// + `NULL == Queue` +/// - ::OL_ERRC_INVALID_NULL_POINTER +OL_APIEXPORT ol_result_t OL_APICALL olDestroyQueue( + // [in] handle of the queue + ol_queue_handle_t Queue); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Wait for the enqueued work on a queue to complete. 
+/// +/// @details +/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// + `NULL == Queue` +/// - ::OL_ERRC_INVALID_NULL_POINTER +OL_APIEXPORT ol_result_t OL_APICALL olWaitQueue( + // [in] handle of the queue + ol_queue_handle_t Queue); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Destroy the event and free all underlying resources. +/// +/// @details +/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// + `NULL == Event` +/// - ::OL_ERRC_INVALID_NULL_POINTER +OL_APIEXPORT ol_result_t OL_APICALL olDestroyEvent( + // [in] handle of the event + ol_event_handle_t Event); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Wait for the event to be complete. +/// +/// @details +/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// + `NULL == Event` +/// - ::OL_ERRC_INVALID_NULL_POINTER +OL_APIEXPORT ol_result_t OL_APICALL olWaitEvent( + // [in] handle of the event + ol_event_handle_t Event); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Create a program for the device from the binary image pointed to by +/// `ProgData`. 
+/// +/// @details +/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// + `NULL == Device` +/// - ::OL_ERRC_INVALID_NULL_POINTER +/// + `NULL == ProgData` +/// + `NULL == Program` +OL_APIEXPORT ol_result_t OL_APICALL olCreateProgram( + // [in] handle of the device + ol_device_handle_t Device, + // [in] pointer to the program binary data + const void *ProgData, + // [in] size of the program binary in bytes + size_t ProgDataSize, + // [out] output pointer for the created program + ol_program_handle_t *Program); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Destroy the program and free all underlying resources. +/// +/// @details +/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// + `NULL == Program` +/// - ::OL_ERRC_INVALID_NULL_POINTER +OL_APIEXPORT ol_result_t OL_APICALL olDestroyProgram( + // [in] handle of the program + ol_program_handle_t Program); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Get a kernel from the function identified by `KernelName` in the +/// given program. +/// +/// @details +/// - The kernel handle returned is owned by the device so does not need to +/// be destroyed. 
+/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// + `NULL == Program` +/// - ::OL_ERRC_INVALID_NULL_POINTER +/// + `NULL == KernelName` +/// + `NULL == Kernel` +OL_APIEXPORT ol_result_t OL_APICALL olGetKernel( + // [in] handle of the program + ol_program_handle_t Program, + // [in] name of the kernel entry point in the program + const char *KernelName, + // [out] output pointer for the fetched kernel + ol_kernel_handle_t *Kernel); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Size-related arguments for a kernel launch. +typedef struct ol_kernel_launch_size_args_t { + size_t Dimensions; /// Number of work dimensions + size_t NumGroupsX; /// Number of work groups on the X dimension + size_t NumGroupsY; /// Number of work groups on the Y dimension + size_t NumGroupsZ; /// Number of work groups on the Z dimension + size_t GroupSizeX; /// Size of a work group on the X dimension. + size_t GroupSizeY; /// Size of a work group on the Y dimension. + size_t GroupSizeZ; /// Size of a work group on the Z dimension. + size_t DynSharedMemory; /// Size of dynamic shared memory in bytes. +} ol_kernel_launch_size_args_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Enqueue a kernel launch with the specified size and parameters. 
+/// +/// @details +/// - If a queue is not specified, kernel execution happens synchronously +/// +/// @returns +/// - ::OL_RESULT_SUCCESS +/// - ::OL_ERRC_UNINITIALIZED +/// - ::OL_ERRC_DEVICE_LOST +/// - ::OL_ERRC_INVALID_ARGUMENT +/// + `Queue == NULL && EventOut != NULL` +/// - ::OL_ERRC_INVALID_DEVICE +/// + If Queue is non-null but does not belong to Device +/// - ::OL_ERRC_INVALID_NULL_HANDLE +/// + `NULL == Device` +/// + `NULL == Kernel` +/// - ::OL_ERRC_INVALID_NULL_POINTER +/// + `NULL == ArgumentsData` +/// + `NULL == LaunchSizeArgs` +OL_APIEXPORT ol_result_t OL_APICALL olLaunchKernel( + // [in][optional] handle of the queue + ol_queue_handle_t Queue, + // [in] handle of the device to execute on + ol_device_handle_t Device, + // [in] handle of the kernel + ol_kernel_handle_t Kernel, + // [in] pointer to the kernel argument struct + const void *ArgumentsData, + // [in] size of the kernel argument struct + size_t ArgumentsSize, + // [in] pointer to the struct containing launch size parameters + const ol_kernel_launch_size_args_t *LaunchSizeArgs, + // [out][optional] optional recorded event for the enqueued operation + ol_event_handle_t *EventOut); /////////////////////////////////////////////////////////////////////////////// /// @brief Function parameters for olGetPlatformInfo @@ -495,21 +727,12 @@ typedef struct ol_get_platform_info_size_params_t { } ol_get_platform_info_size_params_t; /////////////////////////////////////////////////////////////////////////////// -/// @brief Function parameters for olGetDeviceCount +/// @brief Function parameters for olIterateDevices /// @details Each entry is a pointer to the parameter passed to the function; -typedef struct ol_get_device_count_params_t { - ol_platform_handle_t *pPlatform; - uint32_t **pNumDevices; -} ol_get_device_count_params_t; - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Function parameters for olGetDevice -/// @details Each entry is a pointer to 
the parameter passed to the function; -typedef struct ol_get_device_params_t { - ol_platform_handle_t *pPlatform; - uint32_t *pNumEntries; - ol_device_handle_t **pDevices; -} ol_get_device_params_t; +typedef struct ol_iterate_devices_params_t { + ol_device_iterate_cb_t *pCallback; + void **pUserData; +} ol_iterate_devices_params_t; /////////////////////////////////////////////////////////////////////////////// /// @brief Function parameters for olGetDeviceInfo @@ -530,6 +753,111 @@ typedef struct ol_get_device_info_size_params_t { size_t **pPropSizeRet; } ol_get_device_info_size_params_t; +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olMemAlloc +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_mem_alloc_params_t { + ol_device_handle_t *pDevice; + ol_alloc_type_t *pType; + size_t *pSize; + void ***pAllocationOut; +} ol_mem_alloc_params_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olMemFree +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_mem_free_params_t { + void **pAddress; +} ol_mem_free_params_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olMemcpy +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_memcpy_params_t { + ol_queue_handle_t *pQueue; + void **pDstPtr; + ol_device_handle_t *pDstDevice; + void **pSrcPtr; + ol_device_handle_t *pSrcDevice; + size_t *pSize; + ol_event_handle_t **pEventOut; +} ol_memcpy_params_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olCreateQueue +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_create_queue_params_t { + ol_device_handle_t *pDevice; + 
ol_queue_handle_t **pQueue; +} ol_create_queue_params_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olDestroyQueue +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_destroy_queue_params_t { + ol_queue_handle_t *pQueue; +} ol_destroy_queue_params_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olWaitQueue +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_wait_queue_params_t { + ol_queue_handle_t *pQueue; +} ol_wait_queue_params_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olDestroyEvent +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_destroy_event_params_t { + ol_event_handle_t *pEvent; +} ol_destroy_event_params_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olWaitEvent +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_wait_event_params_t { + ol_event_handle_t *pEvent; +} ol_wait_event_params_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olCreateProgram +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_create_program_params_t { + ol_device_handle_t *pDevice; + const void **pProgData; + size_t *pProgDataSize; + ol_program_handle_t **pProgram; +} ol_create_program_params_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olDestroyProgram +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_destroy_program_params_t { + ol_program_handle_t 
*pProgram; +} ol_destroy_program_params_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olGetKernel +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_get_kernel_params_t { + ol_program_handle_t *pProgram; + const char **pKernelName; + ol_kernel_handle_t **pKernel; +} ol_get_kernel_params_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for olLaunchKernel +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct ol_launch_kernel_params_t { + ol_queue_handle_t *pQueue; + ol_device_handle_t *pDevice; + ol_kernel_handle_t *pKernel; + const void **pArgumentsData; + size_t *pArgumentsSize; + const ol_kernel_launch_size_args_t **pLaunchSizeArgs; + ol_event_handle_t **pEventOut; +} ol_launch_kernel_params_t; + /////////////////////////////////////////////////////////////////////////////// /// @brief Variant of olInit that also sets source code location information /// @details See also ::olInit @@ -542,21 +870,6 @@ olInitWithCodeLoc(ol_code_location_t *CodeLocation); OL_APIEXPORT ol_result_t OL_APICALL olShutDownWithCodeLoc(ol_code_location_t *CodeLocation); -/////////////////////////////////////////////////////////////////////////////// -/// @brief Variant of olGetPlatform that also sets source code location -/// information -/// @details See also ::olGetPlatform -OL_APIEXPORT ol_result_t OL_APICALL -olGetPlatformWithCodeLoc(uint32_t NumEntries, ol_platform_handle_t *Platforms, - ol_code_location_t *CodeLocation); - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Variant of olGetPlatformCount that also sets source code location -/// information -/// @details See also ::olGetPlatformCount -OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformCountWithCodeLoc( - uint32_t *NumPlatforms, ol_code_location_t 
*CodeLocation); - /////////////////////////////////////////////////////////////////////////////// /// @brief Variant of olGetPlatformInfo that also sets source code location /// information @@ -574,21 +887,13 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformInfoSizeWithCodeLoc( size_t *PropSizeRet, ol_code_location_t *CodeLocation); /////////////////////////////////////////////////////////////////////////////// -/// @brief Variant of olGetDeviceCount that also sets source code location +/// @brief Variant of olIterateDevices that also sets source code location /// information -/// @details See also ::olGetDeviceCount +/// @details See also ::olIterateDevices OL_APIEXPORT ol_result_t OL_APICALL -olGetDeviceCountWithCodeLoc(ol_platform_handle_t Platform, uint32_t *NumDevices, +olIterateDevicesWithCodeLoc(ol_device_iterate_cb_t Callback, void *UserData, ol_code_location_t *CodeLocation); -/////////////////////////////////////////////////////////////////////////////// -/// @brief Variant of olGetDevice that also sets source code location -/// information -/// @details See also ::olGetDevice -OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceWithCodeLoc( - ol_platform_handle_t Platform, uint32_t NumEntries, - ol_device_handle_t *Devices, ol_code_location_t *CodeLocation); - /////////////////////////////////////////////////////////////////////////////// /// @brief Variant of olGetDeviceInfo that also sets source code location /// information @@ -605,6 +910,96 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfoSizeWithCodeLoc( ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet, ol_code_location_t *CodeLocation); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olMemAlloc that also sets source code location information +/// @details See also ::olMemAlloc +OL_APIEXPORT ol_result_t OL_APICALL olMemAllocWithCodeLoc( + ol_device_handle_t Device, ol_alloc_type_t Type, size_t Size, + void 
**AllocationOut, ol_code_location_t *CodeLocation); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olMemFree that also sets source code location information +/// @details See also ::olMemFree +OL_APIEXPORT ol_result_t OL_APICALL +olMemFreeWithCodeLoc(void *Address, ol_code_location_t *CodeLocation); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olMemcpy that also sets source code location information +/// @details See also ::olMemcpy +OL_APIEXPORT ol_result_t OL_APICALL olMemcpyWithCodeLoc( + ol_queue_handle_t Queue, void *DstPtr, ol_device_handle_t DstDevice, + void *SrcPtr, ol_device_handle_t SrcDevice, size_t Size, + ol_event_handle_t *EventOut, ol_code_location_t *CodeLocation); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olCreateQueue that also sets source code location +/// information +/// @details See also ::olCreateQueue +OL_APIEXPORT ol_result_t OL_APICALL +olCreateQueueWithCodeLoc(ol_device_handle_t Device, ol_queue_handle_t *Queue, + ol_code_location_t *CodeLocation); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olDestroyQueue that also sets source code location +/// information +/// @details See also ::olDestroyQueue +OL_APIEXPORT ol_result_t OL_APICALL olDestroyQueueWithCodeLoc( + ol_queue_handle_t Queue, ol_code_location_t *CodeLocation); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olWaitQueue that also sets source code location +/// information +/// @details See also ::olWaitQueue +OL_APIEXPORT ol_result_t OL_APICALL olWaitQueueWithCodeLoc( + ol_queue_handle_t Queue, ol_code_location_t *CodeLocation); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olDestroyEvent that also sets source code 
location +/// information +/// @details See also ::olDestroyEvent +OL_APIEXPORT ol_result_t OL_APICALL olDestroyEventWithCodeLoc( + ol_event_handle_t Event, ol_code_location_t *CodeLocation); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olWaitEvent that also sets source code location +/// information +/// @details See also ::olWaitEvent +OL_APIEXPORT ol_result_t OL_APICALL olWaitEventWithCodeLoc( + ol_event_handle_t Event, ol_code_location_t *CodeLocation); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olCreateProgram that also sets source code location +/// information +/// @details See also ::olCreateProgram +OL_APIEXPORT ol_result_t OL_APICALL olCreateProgramWithCodeLoc( + ol_device_handle_t Device, const void *ProgData, size_t ProgDataSize, + ol_program_handle_t *Program, ol_code_location_t *CodeLocation); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olDestroyProgram that also sets source code location +/// information +/// @details See also ::olDestroyProgram +OL_APIEXPORT ol_result_t OL_APICALL olDestroyProgramWithCodeLoc( + ol_program_handle_t Program, ol_code_location_t *CodeLocation); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olGetKernel that also sets source code location +/// information +/// @details See also ::olGetKernel +OL_APIEXPORT ol_result_t OL_APICALL olGetKernelWithCodeLoc( + ol_program_handle_t Program, const char *KernelName, + ol_kernel_handle_t *Kernel, ol_code_location_t *CodeLocation); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of olLaunchKernel that also sets source code location +/// information +/// @details See also ::olLaunchKernel +OL_APIEXPORT ol_result_t OL_APICALL olLaunchKernelWithCodeLoc( + ol_queue_handle_t Queue, 
ol_device_handle_t Device, + ol_kernel_handle_t Kernel, const void *ArgumentsData, size_t ArgumentsSize, + const ol_kernel_launch_size_args_t *LaunchSizeArgs, + ol_event_handle_t *EventOut, ol_code_location_t *CodeLocation); + #if defined(__cplusplus) } // extern "C" #endif diff --git a/offload/liboffload/include/generated/OffloadEntryPoints.inc b/offload/liboffload/include/generated/OffloadEntryPoints.inc index 49c1c8169615..d70ebed934dc 100644 --- a/offload/liboffload/include/generated/OffloadEntryPoints.inc +++ b/offload/liboffload/include/generated/OffloadEntryPoints.inc @@ -8,30 +8,30 @@ /////////////////////////////////////////////////////////////////////////////// ol_impl_result_t olInit_val() { - if (true /*enableParameterValidation*/) { + if (offloadConfig().ValidationEnabled) { } - return olInit_impl(); + return llvm::offload::olInit_impl(); } OL_APIEXPORT ol_result_t OL_APICALL olInit() { if (offloadConfig().TracingEnabled) { - std::cout << "---> olInit"; + llvm::errs() << "---> olInit"; } ol_result_t Result = olInit_val(); if (offloadConfig().TracingEnabled) { - std::cout << "()"; - std::cout << "-> " << Result << "\n"; + llvm::errs() << "()"; + llvm::errs() << "-> " << Result << "\n"; if (Result && Result->Details) { - std::cout << " *Error Details* " << Result->Details << " \n"; + llvm::errs() << " *Error Details* " << Result->Details << " \n"; } } return Result; } ol_result_t olInitWithCodeLoc(ol_code_location_t *CodeLocation) { currentCodeLocation() = CodeLocation; - ol_result_t Result = olInit(); + ol_result_t Result = ::olInit(); currentCodeLocation() = nullptr; return Result; @@ -39,109 +39,30 @@ ol_result_t olInitWithCodeLoc(ol_code_location_t *CodeLocation) { /////////////////////////////////////////////////////////////////////////////// ol_impl_result_t olShutDown_val() { - if (true /*enableParameterValidation*/) { + if (offloadConfig().ValidationEnabled) { } - return olShutDown_impl(); + return llvm::offload::olShutDown_impl(); } OL_APIEXPORT 
ol_result_t OL_APICALL olShutDown() { if (offloadConfig().TracingEnabled) { - std::cout << "---> olShutDown"; + llvm::errs() << "---> olShutDown"; } ol_result_t Result = olShutDown_val(); if (offloadConfig().TracingEnabled) { - std::cout << "()"; - std::cout << "-> " << Result << "\n"; + llvm::errs() << "()"; + llvm::errs() << "-> " << Result << "\n"; if (Result && Result->Details) { - std::cout << " *Error Details* " << Result->Details << " \n"; + llvm::errs() << " *Error Details* " << Result->Details << " \n"; } } return Result; } ol_result_t olShutDownWithCodeLoc(ol_code_location_t *CodeLocation) { currentCodeLocation() = CodeLocation; - ol_result_t Result = olShutDown(); - - currentCodeLocation() = nullptr; - return Result; -} - -/////////////////////////////////////////////////////////////////////////////// -ol_impl_result_t olGetPlatform_val(uint32_t NumEntries, - ol_platform_handle_t *Platforms) { - if (true /*enableParameterValidation*/) { - if (NumEntries == 0) { - return OL_ERRC_INVALID_SIZE; - } - - if (NULL == Platforms) { - return OL_ERRC_INVALID_NULL_POINTER; - } - } - - return olGetPlatform_impl(NumEntries, Platforms); -} -OL_APIEXPORT ol_result_t OL_APICALL -olGetPlatform(uint32_t NumEntries, ol_platform_handle_t *Platforms) { - if (offloadConfig().TracingEnabled) { - std::cout << "---> olGetPlatform"; - } - - ol_result_t Result = olGetPlatform_val(NumEntries, Platforms); - - if (offloadConfig().TracingEnabled) { - ol_get_platform_params_t Params = {&NumEntries, &Platforms}; - std::cout << "(" << &Params << ")"; - std::cout << "-> " << Result << "\n"; - if (Result && Result->Details) { - std::cout << " *Error Details* " << Result->Details << " \n"; - } - } - return Result; -} -ol_result_t olGetPlatformWithCodeLoc(uint32_t NumEntries, - ol_platform_handle_t *Platforms, - ol_code_location_t *CodeLocation) { - currentCodeLocation() = CodeLocation; - ol_result_t Result = olGetPlatform(NumEntries, Platforms); - - currentCodeLocation() = nullptr; - return 
Result; -} - -/////////////////////////////////////////////////////////////////////////////// -ol_impl_result_t olGetPlatformCount_val(uint32_t *NumPlatforms) { - if (true /*enableParameterValidation*/) { - if (NULL == NumPlatforms) { - return OL_ERRC_INVALID_NULL_POINTER; - } - } - - return olGetPlatformCount_impl(NumPlatforms); -} -OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformCount(uint32_t *NumPlatforms) { - if (offloadConfig().TracingEnabled) { - std::cout << "---> olGetPlatformCount"; - } - - ol_result_t Result = olGetPlatformCount_val(NumPlatforms); - - if (offloadConfig().TracingEnabled) { - ol_get_platform_count_params_t Params = {&NumPlatforms}; - std::cout << "(" << &Params << ")"; - std::cout << "-> " << Result << "\n"; - if (Result && Result->Details) { - std::cout << " *Error Details* " << Result->Details << " \n"; - } - } - return Result; -} -ol_result_t olGetPlatformCountWithCodeLoc(uint32_t *NumPlatforms, - ol_code_location_t *CodeLocation) { - currentCodeLocation() = CodeLocation; - ol_result_t Result = olGetPlatformCount(NumPlatforms); + ol_result_t Result = ::olShutDown(); currentCodeLocation() = nullptr; return Result; @@ -151,7 +72,7 @@ ol_result_t olGetPlatformCountWithCodeLoc(uint32_t *NumPlatforms, ol_impl_result_t olGetPlatformInfo_val(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t PropSize, void *PropValue) { - if (true /*enableParameterValidation*/) { + if (offloadConfig().ValidationEnabled) { if (PropSize == 0) { return OL_ERRC_INVALID_SIZE; } @@ -165,13 +86,14 @@ ol_impl_result_t olGetPlatformInfo_val(ol_platform_handle_t Platform, } } - return olGetPlatformInfo_impl(Platform, PropName, PropSize, PropValue); + return llvm::offload::olGetPlatformInfo_impl(Platform, PropName, PropSize, + PropValue); } OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformInfo(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t PropSize, void *PropValue) { if (offloadConfig().TracingEnabled) { - std::cout << "---> 
olGetPlatformInfo"; + llvm::errs() << "---> olGetPlatformInfo"; } ol_result_t Result = @@ -180,10 +102,10 @@ olGetPlatformInfo(ol_platform_handle_t Platform, ol_platform_info_t PropName, if (offloadConfig().TracingEnabled) { ol_get_platform_info_params_t Params = {&Platform, &PropName, &PropSize, &PropValue}; - std::cout << "(" << &Params << ")"; - std::cout << "-> " << Result << "\n"; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; if (Result && Result->Details) { - std::cout << " *Error Details* " << Result->Details << " \n"; + llvm::errs() << " *Error Details* " << Result->Details << " \n"; } } return Result; @@ -194,7 +116,7 @@ ol_result_t olGetPlatformInfoWithCodeLoc(ol_platform_handle_t Platform, ol_code_location_t *CodeLocation) { currentCodeLocation() = CodeLocation; ol_result_t Result = - olGetPlatformInfo(Platform, PropName, PropSize, PropValue); + ::olGetPlatformInfo(Platform, PropName, PropSize, PropValue); currentCodeLocation() = nullptr; return Result; @@ -204,7 +126,7 @@ ol_result_t olGetPlatformInfoWithCodeLoc(ol_platform_handle_t Platform, ol_impl_result_t olGetPlatformInfoSize_val(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t *PropSizeRet) { - if (true /*enableParameterValidation*/) { + if (offloadConfig().ValidationEnabled) { if (NULL == Platform) { return OL_ERRC_INVALID_NULL_HANDLE; } @@ -214,13 +136,14 @@ ol_impl_result_t olGetPlatformInfoSize_val(ol_platform_handle_t Platform, } } - return olGetPlatformInfoSize_impl(Platform, PropName, PropSizeRet); + return llvm::offload::olGetPlatformInfoSize_impl(Platform, PropName, + PropSizeRet); } OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformInfoSize(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t *PropSizeRet) { if (offloadConfig().TracingEnabled) { - std::cout << "---> olGetPlatformInfoSize"; + llvm::errs() << "---> olGetPlatformInfoSize"; } ol_result_t Result = @@ -229,10 +152,10 @@ 
olGetPlatformInfoSize(ol_platform_handle_t Platform, if (offloadConfig().TracingEnabled) { ol_get_platform_info_size_params_t Params = {&Platform, &PropName, &PropSizeRet}; - std::cout << "(" << &Params << ")"; - std::cout << "-> " << Result << "\n"; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; if (Result && Result->Details) { - std::cout << " *Error Details* " << Result->Details << " \n"; + llvm::errs() << " *Error Details* " << Result->Details << " \n"; } } return Result; @@ -242,100 +165,43 @@ ol_result_t olGetPlatformInfoSizeWithCodeLoc(ol_platform_handle_t Platform, size_t *PropSizeRet, ol_code_location_t *CodeLocation) { currentCodeLocation() = CodeLocation; - ol_result_t Result = olGetPlatformInfoSize(Platform, PropName, PropSizeRet); + ol_result_t Result = ::olGetPlatformInfoSize(Platform, PropName, PropSizeRet); currentCodeLocation() = nullptr; return Result; } /////////////////////////////////////////////////////////////////////////////// -ol_impl_result_t olGetDeviceCount_val(ol_platform_handle_t Platform, - uint32_t *NumDevices) { - if (true /*enableParameterValidation*/) { - if (NULL == Platform) { - return OL_ERRC_INVALID_NULL_HANDLE; - } - - if (NULL == NumDevices) { - return OL_ERRC_INVALID_NULL_POINTER; - } +ol_impl_result_t olIterateDevices_val(ol_device_iterate_cb_t Callback, + void *UserData) { + if (offloadConfig().ValidationEnabled) { } - return olGetDeviceCount_impl(Platform, NumDevices); + return llvm::offload::olIterateDevices_impl(Callback, UserData); } OL_APIEXPORT ol_result_t OL_APICALL -olGetDeviceCount(ol_platform_handle_t Platform, uint32_t *NumDevices) { +olIterateDevices(ol_device_iterate_cb_t Callback, void *UserData) { if (offloadConfig().TracingEnabled) { - std::cout << "---> olGetDeviceCount"; + llvm::errs() << "---> olIterateDevices"; } - ol_result_t Result = olGetDeviceCount_val(Platform, NumDevices); + ol_result_t Result = olIterateDevices_val(Callback, UserData); if 
(offloadConfig().TracingEnabled) { - ol_get_device_count_params_t Params = {&Platform, &NumDevices}; - std::cout << "(" << &Params << ")"; - std::cout << "-> " << Result << "\n"; + ol_iterate_devices_params_t Params = {&Callback, &UserData}; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; if (Result && Result->Details) { - std::cout << " *Error Details* " << Result->Details << " \n"; + llvm::errs() << " *Error Details* " << Result->Details << " \n"; } } return Result; } -ol_result_t olGetDeviceCountWithCodeLoc(ol_platform_handle_t Platform, - uint32_t *NumDevices, +ol_result_t olIterateDevicesWithCodeLoc(ol_device_iterate_cb_t Callback, + void *UserData, ol_code_location_t *CodeLocation) { currentCodeLocation() = CodeLocation; - ol_result_t Result = olGetDeviceCount(Platform, NumDevices); - - currentCodeLocation() = nullptr; - return Result; -} - -/////////////////////////////////////////////////////////////////////////////// -ol_impl_result_t olGetDevice_val(ol_platform_handle_t Platform, - uint32_t NumEntries, - ol_device_handle_t *Devices) { - if (true /*enableParameterValidation*/) { - if (NumEntries == 0) { - return OL_ERRC_INVALID_SIZE; - } - - if (NULL == Platform) { - return OL_ERRC_INVALID_NULL_HANDLE; - } - - if (NULL == Devices) { - return OL_ERRC_INVALID_NULL_POINTER; - } - } - - return olGetDevice_impl(Platform, NumEntries, Devices); -} -OL_APIEXPORT ol_result_t OL_APICALL olGetDevice(ol_platform_handle_t Platform, - uint32_t NumEntries, - ol_device_handle_t *Devices) { - if (offloadConfig().TracingEnabled) { - std::cout << "---> olGetDevice"; - } - - ol_result_t Result = olGetDevice_val(Platform, NumEntries, Devices); - - if (offloadConfig().TracingEnabled) { - ol_get_device_params_t Params = {&Platform, &NumEntries, &Devices}; - std::cout << "(" << &Params << ")"; - std::cout << "-> " << Result << "\n"; - if (Result && Result->Details) { - std::cout << " *Error Details* " << Result->Details << " \n"; - } - } - 
return Result; -} -ol_result_t olGetDeviceWithCodeLoc(ol_platform_handle_t Platform, - uint32_t NumEntries, - ol_device_handle_t *Devices, - ol_code_location_t *CodeLocation) { - currentCodeLocation() = CodeLocation; - ol_result_t Result = olGetDevice(Platform, NumEntries, Devices); + ol_result_t Result = ::olIterateDevices(Callback, UserData); currentCodeLocation() = nullptr; return Result; @@ -345,7 +211,7 @@ ol_result_t olGetDeviceWithCodeLoc(ol_platform_handle_t Platform, ol_impl_result_t olGetDeviceInfo_val(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue) { - if (true /*enableParameterValidation*/) { + if (offloadConfig().ValidationEnabled) { if (PropSize == 0) { return OL_ERRC_INVALID_SIZE; } @@ -359,14 +225,15 @@ ol_impl_result_t olGetDeviceInfo_val(ol_device_handle_t Device, } } - return olGetDeviceInfo_impl(Device, PropName, PropSize, PropValue); + return llvm::offload::olGetDeviceInfo_impl(Device, PropName, PropSize, + PropValue); } OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfo(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue) { if (offloadConfig().TracingEnabled) { - std::cout << "---> olGetDeviceInfo"; + llvm::errs() << "---> olGetDeviceInfo"; } ol_result_t Result = @@ -375,10 +242,10 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfo(ol_device_handle_t Device, if (offloadConfig().TracingEnabled) { ol_get_device_info_params_t Params = {&Device, &PropName, &PropSize, &PropValue}; - std::cout << "(" << &Params << ")"; - std::cout << "-> " << Result << "\n"; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; if (Result && Result->Details) { - std::cout << " *Error Details* " << Result->Details << " \n"; + llvm::errs() << " *Error Details* " << Result->Details << " \n"; } } return Result; @@ -388,7 +255,7 @@ ol_result_t olGetDeviceInfoWithCodeLoc(ol_device_handle_t Device, size_t PropSize, void *PropValue, ol_code_location_t *CodeLocation) 
{ currentCodeLocation() = CodeLocation; - ol_result_t Result = olGetDeviceInfo(Device, PropName, PropSize, PropValue); + ol_result_t Result = ::olGetDeviceInfo(Device, PropName, PropSize, PropValue); currentCodeLocation() = nullptr; return Result; @@ -398,7 +265,7 @@ ol_result_t olGetDeviceInfoWithCodeLoc(ol_device_handle_t Device, ol_impl_result_t olGetDeviceInfoSize_val(ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet) { - if (true /*enableParameterValidation*/) { + if (offloadConfig().ValidationEnabled) { if (NULL == Device) { return OL_ERRC_INVALID_NULL_HANDLE; } @@ -408,12 +275,12 @@ ol_impl_result_t olGetDeviceInfoSize_val(ol_device_handle_t Device, } } - return olGetDeviceInfoSize_impl(Device, PropName, PropSizeRet); + return llvm::offload::olGetDeviceInfoSize_impl(Device, PropName, PropSizeRet); } OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfoSize( ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet) { if (offloadConfig().TracingEnabled) { - std::cout << "---> olGetDeviceInfoSize"; + llvm::errs() << "---> olGetDeviceInfoSize"; } ol_result_t Result = olGetDeviceInfoSize_val(Device, PropName, PropSizeRet); @@ -421,10 +288,10 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfoSize( if (offloadConfig().TracingEnabled) { ol_get_device_info_size_params_t Params = {&Device, &PropName, &PropSizeRet}; - std::cout << "(" << &Params << ")"; - std::cout << "-> " << Result << "\n"; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; if (Result && Result->Details) { - std::cout << " *Error Details* " << Result->Details << " \n"; + llvm::errs() << " *Error Details* " << Result->Details << " \n"; } } return Result; @@ -434,7 +301,559 @@ ol_result_t olGetDeviceInfoSizeWithCodeLoc(ol_device_handle_t Device, size_t *PropSizeRet, ol_code_location_t *CodeLocation) { currentCodeLocation() = CodeLocation; - ol_result_t Result = olGetDeviceInfoSize(Device, PropName, PropSizeRet); + ol_result_t 
Result = ::olGetDeviceInfoSize(Device, PropName, PropSizeRet); + + currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t olMemAlloc_val(ol_device_handle_t Device, ol_alloc_type_t Type, + size_t Size, void **AllocationOut) { + if (offloadConfig().ValidationEnabled) { + if (Size == 0) { + return OL_ERRC_INVALID_SIZE; + } + + if (NULL == Device) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + + if (NULL == AllocationOut) { + return OL_ERRC_INVALID_NULL_POINTER; + } + } + + return llvm::offload::olMemAlloc_impl(Device, Type, Size, AllocationOut); +} +OL_APIEXPORT ol_result_t OL_APICALL olMemAlloc(ol_device_handle_t Device, + ol_alloc_type_t Type, + size_t Size, + void **AllocationOut) { + if (offloadConfig().TracingEnabled) { + llvm::errs() << "---> olMemAlloc"; + } + + ol_result_t Result = olMemAlloc_val(Device, Type, Size, AllocationOut); + + if (offloadConfig().TracingEnabled) { + ol_mem_alloc_params_t Params = {&Device, &Type, &Size, &AllocationOut}; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result && Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t olMemAllocWithCodeLoc(ol_device_handle_t Device, + ol_alloc_type_t Type, size_t Size, + void **AllocationOut, + ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = ::olMemAlloc(Device, Type, Size, AllocationOut); + + currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t olMemFree_val(void *Address) { + if (offloadConfig().ValidationEnabled) { + if (NULL == Address) { + return OL_ERRC_INVALID_NULL_POINTER; + } + } + + return llvm::offload::olMemFree_impl(Address); +} +OL_APIEXPORT ol_result_t OL_APICALL olMemFree(void *Address) { + if 
(offloadConfig().TracingEnabled) { + llvm::errs() << "---> olMemFree"; + } + + ol_result_t Result = olMemFree_val(Address); + + if (offloadConfig().TracingEnabled) { + ol_mem_free_params_t Params = {&Address}; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result && Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t olMemFreeWithCodeLoc(void *Address, + ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = ::olMemFree(Address); + + currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t olMemcpy_val(ol_queue_handle_t Queue, void *DstPtr, + ol_device_handle_t DstDevice, void *SrcPtr, + ol_device_handle_t SrcDevice, size_t Size, + ol_event_handle_t *EventOut) { + if (offloadConfig().ValidationEnabled) { + if (Queue == NULL && EventOut != NULL) { + return OL_ERRC_INVALID_ARGUMENT; + } + + if (NULL == DstDevice) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + + if (NULL == SrcDevice) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + + if (NULL == DstPtr) { + return OL_ERRC_INVALID_NULL_POINTER; + } + + if (NULL == SrcPtr) { + return OL_ERRC_INVALID_NULL_POINTER; + } + } + + return llvm::offload::olMemcpy_impl(Queue, DstPtr, DstDevice, SrcPtr, + SrcDevice, Size, EventOut); +} +OL_APIEXPORT ol_result_t OL_APICALL +olMemcpy(ol_queue_handle_t Queue, void *DstPtr, ol_device_handle_t DstDevice, + void *SrcPtr, ol_device_handle_t SrcDevice, size_t Size, + ol_event_handle_t *EventOut) { + if (offloadConfig().TracingEnabled) { + llvm::errs() << "---> olMemcpy"; + } + + ol_result_t Result = + olMemcpy_val(Queue, DstPtr, DstDevice, SrcPtr, SrcDevice, Size, EventOut); + + if (offloadConfig().TracingEnabled) { + ol_memcpy_params_t Params = {&Queue, &DstPtr, &DstDevice, &SrcPtr, + &SrcDevice, &Size, &EventOut}; + llvm::errs() << 
"(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result && Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t olMemcpyWithCodeLoc(ol_queue_handle_t Queue, void *DstPtr, + ol_device_handle_t DstDevice, void *SrcPtr, + ol_device_handle_t SrcDevice, size_t Size, + ol_event_handle_t *EventOut, + ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = + ::olMemcpy(Queue, DstPtr, DstDevice, SrcPtr, SrcDevice, Size, EventOut); + + currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t olCreateQueue_val(ol_device_handle_t Device, + ol_queue_handle_t *Queue) { + if (offloadConfig().ValidationEnabled) { + if (NULL == Device) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + + if (NULL == Queue) { + return OL_ERRC_INVALID_NULL_POINTER; + } + } + + return llvm::offload::olCreateQueue_impl(Device, Queue); +} +OL_APIEXPORT ol_result_t OL_APICALL olCreateQueue(ol_device_handle_t Device, + ol_queue_handle_t *Queue) { + if (offloadConfig().TracingEnabled) { + llvm::errs() << "---> olCreateQueue"; + } + + ol_result_t Result = olCreateQueue_val(Device, Queue); + + if (offloadConfig().TracingEnabled) { + ol_create_queue_params_t Params = {&Device, &Queue}; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result && Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t olCreateQueueWithCodeLoc(ol_device_handle_t Device, + ol_queue_handle_t *Queue, + ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = ::olCreateQueue(Device, Queue); + + currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t 
olDestroyQueue_val(ol_queue_handle_t Queue) { + if (offloadConfig().ValidationEnabled) { + if (NULL == Queue) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + } + + return llvm::offload::olDestroyQueue_impl(Queue); +} +OL_APIEXPORT ol_result_t OL_APICALL olDestroyQueue(ol_queue_handle_t Queue) { + if (offloadConfig().TracingEnabled) { + llvm::errs() << "---> olDestroyQueue"; + } + + ol_result_t Result = olDestroyQueue_val(Queue); + + if (offloadConfig().TracingEnabled) { + ol_destroy_queue_params_t Params = {&Queue}; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result && Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t olDestroyQueueWithCodeLoc(ol_queue_handle_t Queue, + ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = ::olDestroyQueue(Queue); + + currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t olWaitQueue_val(ol_queue_handle_t Queue) { + if (offloadConfig().ValidationEnabled) { + if (NULL == Queue) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + } + + return llvm::offload::olWaitQueue_impl(Queue); +} +OL_APIEXPORT ol_result_t OL_APICALL olWaitQueue(ol_queue_handle_t Queue) { + if (offloadConfig().TracingEnabled) { + llvm::errs() << "---> olWaitQueue"; + } + + ol_result_t Result = olWaitQueue_val(Queue); + + if (offloadConfig().TracingEnabled) { + ol_wait_queue_params_t Params = {&Queue}; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result && Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t olWaitQueueWithCodeLoc(ol_queue_handle_t Queue, + ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = ::olWaitQueue(Queue); + + 
currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t olDestroyEvent_val(ol_event_handle_t Event) { + if (offloadConfig().ValidationEnabled) { + if (NULL == Event) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + } + + return llvm::offload::olDestroyEvent_impl(Event); +} +OL_APIEXPORT ol_result_t OL_APICALL olDestroyEvent(ol_event_handle_t Event) { + if (offloadConfig().TracingEnabled) { + llvm::errs() << "---> olDestroyEvent"; + } + + ol_result_t Result = olDestroyEvent_val(Event); + + if (offloadConfig().TracingEnabled) { + ol_destroy_event_params_t Params = {&Event}; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result && Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t olDestroyEventWithCodeLoc(ol_event_handle_t Event, + ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = ::olDestroyEvent(Event); + + currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t olWaitEvent_val(ol_event_handle_t Event) { + if (offloadConfig().ValidationEnabled) { + if (NULL == Event) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + } + + return llvm::offload::olWaitEvent_impl(Event); +} +OL_APIEXPORT ol_result_t OL_APICALL olWaitEvent(ol_event_handle_t Event) { + if (offloadConfig().TracingEnabled) { + llvm::errs() << "---> olWaitEvent"; + } + + ol_result_t Result = olWaitEvent_val(Event); + + if (offloadConfig().TracingEnabled) { + ol_wait_event_params_t Params = {&Event}; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result && Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t 
olWaitEventWithCodeLoc(ol_event_handle_t Event, + ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = ::olWaitEvent(Event); + + currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t olCreateProgram_val(ol_device_handle_t Device, + const void *ProgData, size_t ProgDataSize, + ol_program_handle_t *Program) { + if (offloadConfig().ValidationEnabled) { + if (NULL == Device) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + + if (NULL == ProgData) { + return OL_ERRC_INVALID_NULL_POINTER; + } + + if (NULL == Program) { + return OL_ERRC_INVALID_NULL_POINTER; + } + } + + return llvm::offload::olCreateProgram_impl(Device, ProgData, ProgDataSize, + Program); +} +OL_APIEXPORT ol_result_t OL_APICALL +olCreateProgram(ol_device_handle_t Device, const void *ProgData, + size_t ProgDataSize, ol_program_handle_t *Program) { + if (offloadConfig().TracingEnabled) { + llvm::errs() << "---> olCreateProgram"; + } + + ol_result_t Result = + olCreateProgram_val(Device, ProgData, ProgDataSize, Program); + + if (offloadConfig().TracingEnabled) { + ol_create_program_params_t Params = {&Device, &ProgData, &ProgDataSize, + &Program}; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result && Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t olCreateProgramWithCodeLoc(ol_device_handle_t Device, + const void *ProgData, + size_t ProgDataSize, + ol_program_handle_t *Program, + ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = + ::olCreateProgram(Device, ProgData, ProgDataSize, Program); + + currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t olDestroyProgram_val(ol_program_handle_t Program) { + 
if (offloadConfig().ValidationEnabled) { + if (NULL == Program) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + } + + return llvm::offload::olDestroyProgram_impl(Program); +} +OL_APIEXPORT ol_result_t OL_APICALL +olDestroyProgram(ol_program_handle_t Program) { + if (offloadConfig().TracingEnabled) { + llvm::errs() << "---> olDestroyProgram"; + } + + ol_result_t Result = olDestroyProgram_val(Program); + + if (offloadConfig().TracingEnabled) { + ol_destroy_program_params_t Params = {&Program}; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result && Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t olDestroyProgramWithCodeLoc(ol_program_handle_t Program, + ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = ::olDestroyProgram(Program); + + currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t olGetKernel_val(ol_program_handle_t Program, + const char *KernelName, + ol_kernel_handle_t *Kernel) { + if (offloadConfig().ValidationEnabled) { + if (NULL == Program) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + + if (NULL == KernelName) { + return OL_ERRC_INVALID_NULL_POINTER; + } + + if (NULL == Kernel) { + return OL_ERRC_INVALID_NULL_POINTER; + } + } + + return llvm::offload::olGetKernel_impl(Program, KernelName, Kernel); +} +OL_APIEXPORT ol_result_t OL_APICALL olGetKernel(ol_program_handle_t Program, + const char *KernelName, + ol_kernel_handle_t *Kernel) { + if (offloadConfig().TracingEnabled) { + llvm::errs() << "---> olGetKernel"; + } + + ol_result_t Result = olGetKernel_val(Program, KernelName, Kernel); + + if (offloadConfig().TracingEnabled) { + ol_get_kernel_params_t Params = {&Program, &KernelName, &Kernel}; + llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result 
&& Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t olGetKernelWithCodeLoc(ol_program_handle_t Program, + const char *KernelName, + ol_kernel_handle_t *Kernel, + ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = ::olGetKernel(Program, KernelName, Kernel); + + currentCodeLocation() = nullptr; + return Result; +} + +/////////////////////////////////////////////////////////////////////////////// +ol_impl_result_t +olLaunchKernel_val(ol_queue_handle_t Queue, ol_device_handle_t Device, + ol_kernel_handle_t Kernel, const void *ArgumentsData, + size_t ArgumentsSize, + const ol_kernel_launch_size_args_t *LaunchSizeArgs, + ol_event_handle_t *EventOut) { + if (offloadConfig().ValidationEnabled) { + if (Queue == NULL && EventOut != NULL) { + return OL_ERRC_INVALID_ARGUMENT; + } + + if (NULL == Device) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + + if (NULL == Kernel) { + return OL_ERRC_INVALID_NULL_HANDLE; + } + + if (NULL == ArgumentsData) { + return OL_ERRC_INVALID_NULL_POINTER; + } + + if (NULL == LaunchSizeArgs) { + return OL_ERRC_INVALID_NULL_POINTER; + } + } + + return llvm::offload::olLaunchKernel_impl(Queue, Device, Kernel, + ArgumentsData, ArgumentsSize, + LaunchSizeArgs, EventOut); +} +OL_APIEXPORT ol_result_t OL_APICALL olLaunchKernel( + ol_queue_handle_t Queue, ol_device_handle_t Device, + ol_kernel_handle_t Kernel, const void *ArgumentsData, size_t ArgumentsSize, + const ol_kernel_launch_size_args_t *LaunchSizeArgs, + ol_event_handle_t *EventOut) { + if (offloadConfig().TracingEnabled) { + llvm::errs() << "---> olLaunchKernel"; + } + + ol_result_t Result = + olLaunchKernel_val(Queue, Device, Kernel, ArgumentsData, ArgumentsSize, + LaunchSizeArgs, EventOut); + + if (offloadConfig().TracingEnabled) { + ol_launch_kernel_params_t Params = { + &Queue, &Device, &Kernel, &ArgumentsData, + &ArgumentsSize, &LaunchSizeArgs, &EventOut}; + 
llvm::errs() << "(" << &Params << ")"; + llvm::errs() << "-> " << Result << "\n"; + if (Result && Result->Details) { + llvm::errs() << " *Error Details* " << Result->Details << " \n"; + } + } + return Result; +} +ol_result_t olLaunchKernelWithCodeLoc( + ol_queue_handle_t Queue, ol_device_handle_t Device, + ol_kernel_handle_t Kernel, const void *ArgumentsData, size_t ArgumentsSize, + const ol_kernel_launch_size_args_t *LaunchSizeArgs, + ol_event_handle_t *EventOut, ol_code_location_t *CodeLocation) { + currentCodeLocation() = CodeLocation; + ol_result_t Result = + ::olLaunchKernel(Queue, Device, Kernel, ArgumentsData, ArgumentsSize, + LaunchSizeArgs, EventOut); currentCodeLocation() = nullptr; return Result; diff --git a/offload/liboffload/include/generated/OffloadFuncs.inc b/offload/liboffload/include/generated/OffloadFuncs.inc index 48115493c790..78ff9ddb8279 100644 --- a/offload/liboffload/include/generated/OffloadFuncs.inc +++ b/offload/liboffload/include/generated/OffloadFuncs.inc @@ -12,23 +12,41 @@ OFFLOAD_FUNC(olInit) OFFLOAD_FUNC(olShutDown) -OFFLOAD_FUNC(olGetPlatform) -OFFLOAD_FUNC(olGetPlatformCount) OFFLOAD_FUNC(olGetPlatformInfo) OFFLOAD_FUNC(olGetPlatformInfoSize) -OFFLOAD_FUNC(olGetDeviceCount) -OFFLOAD_FUNC(olGetDevice) +OFFLOAD_FUNC(olIterateDevices) OFFLOAD_FUNC(olGetDeviceInfo) OFFLOAD_FUNC(olGetDeviceInfoSize) +OFFLOAD_FUNC(olMemAlloc) +OFFLOAD_FUNC(olMemFree) +OFFLOAD_FUNC(olMemcpy) +OFFLOAD_FUNC(olCreateQueue) +OFFLOAD_FUNC(olDestroyQueue) +OFFLOAD_FUNC(olWaitQueue) +OFFLOAD_FUNC(olDestroyEvent) +OFFLOAD_FUNC(olWaitEvent) +OFFLOAD_FUNC(olCreateProgram) +OFFLOAD_FUNC(olDestroyProgram) +OFFLOAD_FUNC(olGetKernel) +OFFLOAD_FUNC(olLaunchKernel) OFFLOAD_FUNC(olInitWithCodeLoc) OFFLOAD_FUNC(olShutDownWithCodeLoc) -OFFLOAD_FUNC(olGetPlatformWithCodeLoc) -OFFLOAD_FUNC(olGetPlatformCountWithCodeLoc) OFFLOAD_FUNC(olGetPlatformInfoWithCodeLoc) OFFLOAD_FUNC(olGetPlatformInfoSizeWithCodeLoc) -OFFLOAD_FUNC(olGetDeviceCountWithCodeLoc) 
-OFFLOAD_FUNC(olGetDeviceWithCodeLoc) +OFFLOAD_FUNC(olIterateDevicesWithCodeLoc) OFFLOAD_FUNC(olGetDeviceInfoWithCodeLoc) OFFLOAD_FUNC(olGetDeviceInfoSizeWithCodeLoc) +OFFLOAD_FUNC(olMemAllocWithCodeLoc) +OFFLOAD_FUNC(olMemFreeWithCodeLoc) +OFFLOAD_FUNC(olMemcpyWithCodeLoc) +OFFLOAD_FUNC(olCreateQueueWithCodeLoc) +OFFLOAD_FUNC(olDestroyQueueWithCodeLoc) +OFFLOAD_FUNC(olWaitQueueWithCodeLoc) +OFFLOAD_FUNC(olDestroyEventWithCodeLoc) +OFFLOAD_FUNC(olWaitEventWithCodeLoc) +OFFLOAD_FUNC(olCreateProgramWithCodeLoc) +OFFLOAD_FUNC(olDestroyProgramWithCodeLoc) +OFFLOAD_FUNC(olGetKernelWithCodeLoc) +OFFLOAD_FUNC(olLaunchKernelWithCodeLoc) #undef OFFLOAD_FUNC diff --git a/offload/liboffload/include/generated/OffloadImplFuncDecls.inc b/offload/liboffload/include/generated/OffloadImplFuncDecls.inc index 5b26b2653a05..ced659c2a4bd 100644 --- a/offload/liboffload/include/generated/OffloadImplFuncDecls.inc +++ b/offload/liboffload/include/generated/OffloadImplFuncDecls.inc @@ -9,11 +9,6 @@ ol_impl_result_t olInit_impl(); ol_impl_result_t olShutDown_impl(); -ol_impl_result_t olGetPlatform_impl(uint32_t NumEntries, - ol_platform_handle_t *Platforms); - -ol_impl_result_t olGetPlatformCount_impl(uint32_t *NumPlatforms); - ol_impl_result_t olGetPlatformInfo_impl(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t PropSize, void *PropValue); @@ -22,12 +17,8 @@ ol_impl_result_t olGetPlatformInfoSize_impl(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t *PropSizeRet); -ol_impl_result_t olGetDeviceCount_impl(ol_platform_handle_t Platform, - uint32_t *NumDevices); - -ol_impl_result_t olGetDevice_impl(ol_platform_handle_t Platform, - uint32_t NumEntries, - ol_device_handle_t *Devices); +ol_impl_result_t olIterateDevices_impl(ol_device_iterate_cb_t Callback, + void *UserData); ol_impl_result_t olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName, @@ -36,3 +27,42 @@ ol_impl_result_t olGetDeviceInfo_impl(ol_device_handle_t Device, 
ol_impl_result_t olGetDeviceInfoSize_impl(ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet); + +ol_impl_result_t olMemAlloc_impl(ol_device_handle_t Device, + ol_alloc_type_t Type, size_t Size, + void **AllocationOut); + +ol_impl_result_t olMemFree_impl(void *Address); + +ol_impl_result_t olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr, + ol_device_handle_t DstDevice, void *SrcPtr, + ol_device_handle_t SrcDevice, size_t Size, + ol_event_handle_t *EventOut); + +ol_impl_result_t olCreateQueue_impl(ol_device_handle_t Device, + ol_queue_handle_t *Queue); + +ol_impl_result_t olDestroyQueue_impl(ol_queue_handle_t Queue); + +ol_impl_result_t olWaitQueue_impl(ol_queue_handle_t Queue); + +ol_impl_result_t olDestroyEvent_impl(ol_event_handle_t Event); + +ol_impl_result_t olWaitEvent_impl(ol_event_handle_t Event); + +ol_impl_result_t olCreateProgram_impl(ol_device_handle_t Device, + const void *ProgData, size_t ProgDataSize, + ol_program_handle_t *Program); + +ol_impl_result_t olDestroyProgram_impl(ol_program_handle_t Program); + +ol_impl_result_t olGetKernel_impl(ol_program_handle_t Program, + const char *KernelName, + ol_kernel_handle_t *Kernel); + +ol_impl_result_t +olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device, + ol_kernel_handle_t Kernel, const void *ArgumentsData, + size_t ArgumentsSize, + const ol_kernel_launch_size_args_t *LaunchSizeArgs, + ol_event_handle_t *EventOut); diff --git a/offload/liboffload/include/generated/OffloadPrint.hpp b/offload/liboffload/include/generated/OffloadPrint.hpp index 8981bb054a4c..7f5e33aea6f7 100644 --- a/offload/liboffload/include/generated/OffloadPrint.hpp +++ b/offload/liboffload/include/generated/OffloadPrint.hpp @@ -11,31 +11,40 @@ #pragma once #include -#include +#include template -inline ol_result_t printPtr(std::ostream &os, const T *ptr); +inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr); template -inline void printTagged(std::ostream &os, const void *ptr, T 
value, +inline void printTagged(llvm::raw_ostream &os, const void *ptr, T value, size_t size); template struct is_handle : std::false_type {}; template <> struct is_handle : std::true_type {}; template <> struct is_handle : std::true_type {}; template <> struct is_handle : std::true_type {}; +template <> struct is_handle : std::true_type {}; +template <> struct is_handle : std::true_type {}; +template <> struct is_handle : std::true_type {}; template inline constexpr bool is_handle_v = is_handle::value; -inline std::ostream &operator<<(std::ostream &os, enum ol_errc_t value); -inline std::ostream &operator<<(std::ostream &os, - enum ol_platform_info_t value); -inline std::ostream &operator<<(std::ostream &os, - enum ol_platform_backend_t value); -inline std::ostream &operator<<(std::ostream &os, enum ol_device_type_t value); -inline std::ostream &operator<<(std::ostream &os, enum ol_device_info_t value); +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_errc_t value); +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_platform_info_t value); +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_platform_backend_t value); +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_device_type_t value); +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_device_info_t value); +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_alloc_type_t value); /////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the ol_errc_t type -/// @returns std::ostream & -inline std::ostream &operator<<(std::ostream &os, enum ol_errc_t value) { +/// @returns llvm::raw_ostream & +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_errc_t value) { switch (value) { case OL_ERRC_SUCCESS: os << "OL_ERRC_SUCCESS"; @@ -46,24 +55,21 @@ inline std::ostream &operator<<(std::ostream &os, enum ol_errc_t value) { case 
OL_ERRC_INVALID_PLATFORM: os << "OL_ERRC_INVALID_PLATFORM"; break; - case OL_ERRC_DEVICE_NOT_FOUND: - os << "OL_ERRC_DEVICE_NOT_FOUND"; - break; case OL_ERRC_INVALID_DEVICE: os << "OL_ERRC_INVALID_DEVICE"; break; - case OL_ERRC_DEVICE_LOST: - os << "OL_ERRC_DEVICE_LOST"; + case OL_ERRC_INVALID_QUEUE: + os << "OL_ERRC_INVALID_QUEUE"; break; - case OL_ERRC_UNINITIALIZED: - os << "OL_ERRC_UNINITIALIZED"; + case OL_ERRC_INVALID_EVENT: + os << "OL_ERRC_INVALID_EVENT"; + break; + case OL_ERRC_INVALID_KERNEL_NAME: + os << "OL_ERRC_INVALID_KERNEL_NAME"; break; case OL_ERRC_OUT_OF_RESOURCES: os << "OL_ERRC_OUT_OF_RESOURCES"; break; - case OL_ERRC_UNSUPPORTED_VERSION: - os << "OL_ERRC_UNSUPPORTED_VERSION"; - break; case OL_ERRC_UNSUPPORTED_FEATURE: os << "OL_ERRC_UNSUPPORTED_FEATURE"; break; @@ -97,9 +103,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ol_errc_t value) { /////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the ol_platform_info_t type -/// @returns std::ostream & -inline std::ostream &operator<<(std::ostream &os, - enum ol_platform_info_t value) { +/// @returns llvm::raw_ostream & +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_platform_info_t value) { switch (value) { case OL_PLATFORM_INFO_NAME: os << "OL_PLATFORM_INFO_NAME"; @@ -122,9 +128,9 @@ inline std::ostream &operator<<(std::ostream &os, /////////////////////////////////////////////////////////////////////////////// /// @brief Print type-tagged ol_platform_info_t enum value -/// @returns std::ostream & +/// @returns llvm::raw_ostream & template <> -inline void printTagged(std::ostream &os, const void *ptr, +inline void printTagged(llvm::raw_ostream &os, const void *ptr, ol_platform_info_t value, size_t size) { if (ptr == NULL) { printPtr(os, ptr); @@ -159,9 +165,9 @@ inline void printTagged(std::ostream &os, const void *ptr, } /////////////////////////////////////////////////////////////////////////////// /// 
@brief Print operator for the ol_platform_backend_t type -/// @returns std::ostream & -inline std::ostream &operator<<(std::ostream &os, - enum ol_platform_backend_t value) { +/// @returns llvm::raw_ostream & +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_platform_backend_t value) { switch (value) { case OL_PLATFORM_BACKEND_UNKNOWN: os << "OL_PLATFORM_BACKEND_UNKNOWN"; @@ -172,6 +178,9 @@ inline std::ostream &operator<<(std::ostream &os, case OL_PLATFORM_BACKEND_AMDGPU: os << "OL_PLATFORM_BACKEND_AMDGPU"; break; + case OL_PLATFORM_BACKEND_HOST: + os << "OL_PLATFORM_BACKEND_HOST"; + break; default: os << "unknown enumerator"; break; @@ -181,8 +190,9 @@ inline std::ostream &operator<<(std::ostream &os, /////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the ol_device_type_t type -/// @returns std::ostream & -inline std::ostream &operator<<(std::ostream &os, enum ol_device_type_t value) { +/// @returns llvm::raw_ostream & +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_device_type_t value) { switch (value) { case OL_DEVICE_TYPE_DEFAULT: os << "OL_DEVICE_TYPE_DEFAULT"; @@ -205,8 +215,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ol_device_type_t value) { /////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the ol_device_info_t type -/// @returns std::ostream & -inline std::ostream &operator<<(std::ostream &os, enum ol_device_info_t value) { +/// @returns llvm::raw_ostream & +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_device_info_t value) { switch (value) { case OL_DEVICE_INFO_TYPE: os << "OL_DEVICE_INFO_TYPE"; @@ -232,9 +243,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ol_device_info_t value) { /////////////////////////////////////////////////////////////////////////////// /// @brief Print type-tagged ol_device_info_t enum value -/// @returns std::ostream 
& +/// @returns llvm::raw_ostream & template <> -inline void printTagged(std::ostream &os, const void *ptr, +inline void printTagged(llvm::raw_ostream &os, const void *ptr, ol_device_info_t value, size_t size) { if (ptr == NULL) { printPtr(os, ptr); @@ -274,9 +285,30 @@ inline void printTagged(std::ostream &os, const void *ptr, break; } } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the ol_alloc_type_t type +/// @returns llvm::raw_ostream & +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + enum ol_alloc_type_t value) { + switch (value) { + case OL_ALLOC_TYPE_HOST: + os << "OL_ALLOC_TYPE_HOST"; + break; + case OL_ALLOC_TYPE_DEVICE: + os << "OL_ALLOC_TYPE_DEVICE"; + break; + case OL_ALLOC_TYPE_MANAGED: + os << "OL_ALLOC_TYPE_MANAGED"; + break; + default: + os << "unknown enumerator"; + break; + } + return os; +} -inline std::ostream &operator<<(std::ostream &os, - const ol_error_struct_t *Err) { +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const ol_error_struct_t *Err) { if (Err == nullptr) { os << "OL_SUCCESS"; } else { @@ -284,34 +316,64 @@ inline std::ostream &operator<<(std::ostream &os, } return os; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the ol_code_location_t type +/// @returns llvm::raw_ostream & -inline std::ostream &operator<<(std::ostream &os, - const struct ol_get_platform_params_t *params) { - os << ".NumEntries = "; - os << *params->pNumEntries; +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const struct ol_code_location_t params) { + os << "(struct ol_code_location_t){"; + os << ".FunctionName = "; + printPtr(os, params.FunctionName); os << ", "; - os << ".Platforms = "; - os << "{"; - for (size_t i = 0; i < *params->pNumEntries; i++) { - if (i > 0) { - os << ", "; - } - printPtr(os, (*params->pPlatforms)[i]); - } + os << ".SourceFile = "; + printPtr(os, 
params.SourceFile); + os << ", "; + os << ".LineNumber = "; + os << params.LineNumber; + os << ", "; + os << ".ColumnNumber = "; + os << params.ColumnNumber; + os << "}"; + return os; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the ol_kernel_launch_size_args_t type +/// @returns llvm::raw_ostream & + +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, + const struct ol_kernel_launch_size_args_t params) { + os << "(struct ol_kernel_launch_size_args_t){"; + os << ".Dimensions = "; + os << params.Dimensions; + os << ", "; + os << ".NumGroupsX = "; + os << params.NumGroupsX; + os << ", "; + os << ".NumGroupsY = "; + os << params.NumGroupsY; + os << ", "; + os << ".NumGroupsZ = "; + os << params.NumGroupsZ; + os << ", "; + os << ".GroupSizeX = "; + os << params.GroupSizeX; + os << ", "; + os << ".GroupSizeY = "; + os << params.GroupSizeY; + os << ", "; + os << ".GroupSizeZ = "; + os << params.GroupSizeZ; + os << ", "; + os << ".DynSharedMemory = "; + os << params.DynSharedMemory; os << "}"; return os; } -inline std::ostream & -operator<<(std::ostream &os, - const struct ol_get_platform_count_params_t *params) { - os << ".NumPlatforms = "; - printPtr(os, *params->pNumPlatforms); - return os; -} - -inline std::ostream & -operator<<(std::ostream &os, +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, const struct ol_get_platform_info_params_t *params) { os << ".Platform = "; printPtr(os, *params->pPlatform); @@ -327,8 +389,8 @@ operator<<(std::ostream &os, return os; } -inline std::ostream & -operator<<(std::ostream &os, +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, const struct ol_get_platform_info_size_params_t *params) { os << ".Platform = "; printPtr(os, *params->pPlatform); @@ -341,39 +403,20 @@ operator<<(std::ostream &os, return os; } -inline std::ostream & -operator<<(std::ostream &os, - const struct ol_get_device_count_params_t *params) { - os << 
".Platform = "; - printPtr(os, *params->pPlatform); +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, + const struct ol_iterate_devices_params_t *params) { + os << ".Callback = "; + os << reinterpret_cast(*params->pCallback); os << ", "; - os << ".NumDevices = "; - printPtr(os, *params->pNumDevices); + os << ".UserData = "; + printPtr(os, *params->pUserData); return os; } -inline std::ostream &operator<<(std::ostream &os, - const struct ol_get_device_params_t *params) { - os << ".Platform = "; - printPtr(os, *params->pPlatform); - os << ", "; - os << ".NumEntries = "; - os << *params->pNumEntries; - os << ", "; - os << ".Devices = "; - os << "{"; - for (size_t i = 0; i < *params->pNumEntries; i++) { - if (i > 0) { - os << ", "; - } - printPtr(os, (*params->pDevices)[i]); - } - os << "}"; - return os; -} - -inline std::ostream & -operator<<(std::ostream &os, const struct ol_get_device_info_params_t *params) { +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, + const struct ol_get_device_info_params_t *params) { os << ".Device = "; printPtr(os, *params->pDevice); os << ", "; @@ -388,8 +431,8 @@ operator<<(std::ostream &os, const struct ol_get_device_info_params_t *params) { return os; } -inline std::ostream & -operator<<(std::ostream &os, +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, const struct ol_get_device_info_size_params_t *params) { os << ".Device = "; printPtr(os, *params->pDevice); @@ -402,10 +445,163 @@ operator<<(std::ostream &os, return os; } +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, const struct ol_mem_alloc_params_t *params) { + os << ".Device = "; + printPtr(os, *params->pDevice); + os << ", "; + os << ".Type = "; + os << *params->pType; + os << ", "; + os << ".Size = "; + os << *params->pSize; + os << ", "; + os << ".AllocationOut = "; + printPtr(os, *params->pAllocationOut); + return os; +} + +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, const struct ol_mem_free_params_t 
*params) { + os << ".Address = "; + printPtr(os, *params->pAddress); + return os; +} + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const struct ol_memcpy_params_t *params) { + os << ".Queue = "; + printPtr(os, *params->pQueue); + os << ", "; + os << ".DstPtr = "; + printPtr(os, *params->pDstPtr); + os << ", "; + os << ".DstDevice = "; + printPtr(os, *params->pDstDevice); + os << ", "; + os << ".SrcPtr = "; + printPtr(os, *params->pSrcPtr); + os << ", "; + os << ".SrcDevice = "; + printPtr(os, *params->pSrcDevice); + os << ", "; + os << ".Size = "; + os << *params->pSize; + os << ", "; + os << ".EventOut = "; + printPtr(os, *params->pEventOut); + return os; +} + +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, + const struct ol_create_queue_params_t *params) { + os << ".Device = "; + printPtr(os, *params->pDevice); + os << ", "; + os << ".Queue = "; + printPtr(os, *params->pQueue); + return os; +} + +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, + const struct ol_destroy_queue_params_t *params) { + os << ".Queue = "; + printPtr(os, *params->pQueue); + return os; +} + +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, const struct ol_wait_queue_params_t *params) { + os << ".Queue = "; + printPtr(os, *params->pQueue); + return os; +} + +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, + const struct ol_destroy_event_params_t *params) { + os << ".Event = "; + printPtr(os, *params->pEvent); + return os; +} + +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, const struct ol_wait_event_params_t *params) { + os << ".Event = "; + printPtr(os, *params->pEvent); + return os; +} + +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, + const struct ol_create_program_params_t *params) { + os << ".Device = "; + printPtr(os, *params->pDevice); + os << ", "; + os << ".ProgData = "; + printPtr(os, *params->pProgData); + os << ", "; + os << ".ProgDataSize = "; + os << 
*params->pProgDataSize; + os << ", "; + os << ".Program = "; + printPtr(os, *params->pProgram); + return os; +} + +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, + const struct ol_destroy_program_params_t *params) { + os << ".Program = "; + printPtr(os, *params->pProgram); + return os; +} + +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, const struct ol_get_kernel_params_t *params) { + os << ".Program = "; + printPtr(os, *params->pProgram); + os << ", "; + os << ".KernelName = "; + printPtr(os, *params->pKernelName); + os << ", "; + os << ".Kernel = "; + printPtr(os, *params->pKernel); + return os; +} + +inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, + const struct ol_launch_kernel_params_t *params) { + os << ".Queue = "; + printPtr(os, *params->pQueue); + os << ", "; + os << ".Device = "; + printPtr(os, *params->pDevice); + os << ", "; + os << ".Kernel = "; + printPtr(os, *params->pKernel); + os << ", "; + os << ".ArgumentsData = "; + printPtr(os, *params->pArgumentsData); + os << ", "; + os << ".ArgumentsSize = "; + os << *params->pArgumentsSize; + os << ", "; + os << ".LaunchSizeArgs = "; + printPtr(os, *params->pLaunchSizeArgs); + os << ", "; + os << ".EventOut = "; + printPtr(os, *params->pEventOut); + return os; +} + /////////////////////////////////////////////////////////////////////////////// // @brief Print pointer value template -inline ol_result_t printPtr(std::ostream &os, const T *ptr) { +inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr) { if (ptr == nullptr) { os << "nullptr"; } else if constexpr (std::is_pointer_v) { diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 457f1053f163..d956d274b5eb 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -19,27 +19,6 @@ #include -using namespace llvm; -using namespace llvm::omp::target::plugin; - -// Handle type definitions. 
Ideally these would be 1:1 with the plugins -struct ol_device_handle_t_ { - int DeviceNum; - GenericDeviceTy &Device; - ol_platform_handle_t Platform; -}; - -struct ol_platform_handle_t_ { - std::unique_ptr Plugin; - std::vector Devices; -}; - -using PlatformVecT = SmallVector; -PlatformVecT &Platforms() { - static PlatformVecT Platforms; - return Platforms; -} - // TODO: Some plugins expect to be linked into libomptarget which defines these // symbols to implement ompt callbacks. The least invasive workaround here is to // define them in libLLVMOffload as false/null so they are never used. In future @@ -55,6 +34,97 @@ ompt_function_lookup_t lookupCallbackByName = nullptr; } // namespace llvm::omp::target #endif +using namespace llvm::omp::target; +using namespace llvm::omp::target::plugin; + +// Handle type definitions. Ideally these would be 1:1 with the plugins, but +// we add some additional data here for now to avoid churn in the plugin +// interface. +struct ol_device_impl_t { + ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device, + ol_platform_handle_t Platform) + : DeviceNum(DeviceNum), Device(Device), Platform(Platform) {} + int DeviceNum; + GenericDeviceTy *Device; + ol_platform_handle_t Platform; +}; + +struct ol_platform_impl_t { + ol_platform_impl_t(std::unique_ptr Plugin, + std::vector Devices, + ol_platform_backend_t BackendType) + : Plugin(std::move(Plugin)), Devices(Devices), BackendType(BackendType) {} + std::unique_ptr Plugin; + std::vector Devices; + ol_platform_backend_t BackendType; +}; + +struct ol_queue_impl_t { + ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device) + : AsyncInfo(AsyncInfo), Device(Device) {} + __tgt_async_info *AsyncInfo; + ol_device_handle_t Device; +}; + +struct ol_event_impl_t { + ol_event_impl_t(void *EventInfo, ol_queue_handle_t Queue) + : EventInfo(EventInfo), Queue(Queue) {} + ~ol_event_impl_t() { (void)Queue->Device->Device->destroyEvent(EventInfo); } + void *EventInfo; + ol_queue_handle_t 
Queue; +}; + +struct ol_program_impl_t { + ol_program_impl_t(plugin::DeviceImageTy *Image, + std::unique_ptr ImageData, + const __tgt_device_image &DeviceImage) + : Image(Image), ImageData(std::move(ImageData)), + DeviceImage(DeviceImage) {} + plugin::DeviceImageTy *Image; + std::unique_ptr ImageData; + __tgt_device_image DeviceImage; +}; + +namespace llvm { +namespace offload { + +struct AllocInfo { + ol_device_handle_t Device; + ol_alloc_type_t Type; +}; + +using AllocInfoMapT = DenseMap; +AllocInfoMapT &allocInfoMap() { + static AllocInfoMapT AllocInfoMap{}; + return AllocInfoMap; +} + +using PlatformVecT = SmallVector; +PlatformVecT &Platforms() { + static PlatformVecT Platforms; + return Platforms; +} + +ol_device_handle_t HostDevice() { + // The host platform is always inserted last + return &Platforms().back().Devices[0]; +} + +template ol_impl_result_t olDestroy(HandleT Handle) { + delete Handle; + return OL_SUCCESS; +} + +constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) { + if (Name == "amdgpu") { + return OL_PLATFORM_BACKEND_AMDGPU; + } else if (Name == "cuda") { + return OL_PLATFORM_BACKEND_CUDA; + } else { + return OL_PLATFORM_BACKEND_UNKNOWN; + } +} + // Every plugin exports this method to create an instance of the plugin type. #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name(); #include "Shared/Targets.def" @@ -63,26 +133,36 @@ void initPlugins() { // Attempt to create an instance of each supported plugin. 
#define PLUGIN_TARGET(Name) \ do { \ - Platforms().emplace_back(ol_platform_handle_t_{ \ - std::unique_ptr(createPlugin_##Name()), {}}); \ + Platforms().emplace_back(ol_platform_impl_t{ \ + std::unique_ptr(createPlugin_##Name()), \ + {}, \ + pluginNameToBackend(#Name)}); \ } while (false); #include "Shared/Targets.def" - // Preemptively initialize all devices in the plugin so we can just return - // them from deviceGet + // Preemptively initialize all devices in the plugin for (auto &Platform : Platforms()) { auto Err = Platform.Plugin->init(); [[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices(); DevNum++) { if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { - Platform.Devices.emplace_back(ol_device_handle_t_{ - DevNum, Platform.Plugin->getDevice(DevNum), &Platform}); + Platform.Devices.emplace_back(ol_device_impl_t{ + DevNum, &Platform.Plugin->getDevice(DevNum), &Platform}); } } } + // Add the special host device + auto &HostPlatform = Platforms().emplace_back( + ol_platform_impl_t{nullptr, + {ol_device_impl_t{-1, nullptr, nullptr}}, + OL_PLATFORM_BACKEND_HOST}); + HostDevice()->Platform = &HostPlatform; + offloadConfig().TracingEnabled = std::getenv("OFFLOAD_TRACE"); + offloadConfig().ValidationEnabled = + !std::getenv("OFFLOAD_DISABLE_VALIDATION"); } // TODO: We can properly reference count here and manage the resources in a more @@ -95,36 +175,16 @@ ol_impl_result_t olInit_impl() { } ol_impl_result_t olShutDown_impl() { return OL_SUCCESS; } -ol_impl_result_t olGetPlatformCount_impl(uint32_t *NumPlatforms) { - *NumPlatforms = Platforms().size(); - return OL_SUCCESS; -} - -ol_impl_result_t olGetPlatform_impl(uint32_t NumEntries, - ol_platform_handle_t *PlatformsOut) { - if (NumEntries > Platforms().size()) { - return {OL_ERRC_INVALID_SIZE, - std::string{formatv("{0} platform(s) available but {1} requested.", - Platforms().size(), NumEntries)}}; - } - - for (uint32_t 
PlatformIndex = 0; PlatformIndex < NumEntries; - PlatformIndex++) { - PlatformsOut[PlatformIndex] = &(Platforms())[PlatformIndex]; - } - - return OL_SUCCESS; -} - ol_impl_result_t olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); + bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST; switch (PropName) { case OL_PLATFORM_INFO_NAME: - return ReturnValue(Platform->Plugin->getName()); + return ReturnValue(IsHost ? "Host" : Platform->Plugin->getName()); case OL_PLATFORM_INFO_VENDOR_NAME: // TODO: Implement this return ReturnValue("Unknown platform vendor"); @@ -135,14 +195,7 @@ ol_impl_result_t olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, .c_str()); } case OL_PLATFORM_INFO_BACKEND: { - auto PluginName = Platform->Plugin->getName(); - if (PluginName == StringRef("CUDA")) { - return ReturnValue(OL_PLATFORM_BACKEND_CUDA); - } else if (PluginName == StringRef("AMDGPU")) { - return ReturnValue(OL_PLATFORM_BACKEND_AMDGPU); - } else { - return ReturnValue(OL_PLATFORM_BACKEND_UNKNOWN); - } + return ReturnValue(Platform->BackendType); } default: return OL_ERRC_INVALID_ENUMERATION; @@ -165,27 +218,6 @@ ol_impl_result_t olGetPlatformInfoSize_impl(ol_platform_handle_t Platform, PropSizeRet); } -ol_impl_result_t olGetDeviceCount_impl(ol_platform_handle_t Platform, - uint32_t *pNumDevices) { - *pNumDevices = static_cast(Platform->Devices.size()); - - return OL_SUCCESS; -} - -ol_impl_result_t olGetDevice_impl(ol_platform_handle_t Platform, - uint32_t NumEntries, - ol_device_handle_t *Devices) { - if (NumEntries > Platform->Devices.size()) { - return OL_ERRC_INVALID_SIZE; - } - - for (uint32_t DeviceIndex = 0; DeviceIndex < NumEntries; DeviceIndex++) { - Devices[DeviceIndex] = &(Platform->Devices[DeviceIndex]); - } - - return OL_SUCCESS; -} - ol_impl_result_t olGetDeviceInfoImplDetail(ol_device_handle_t 
Device, ol_device_info_t PropName, size_t PropSize, void *PropValue, @@ -193,12 +225,12 @@ ol_impl_result_t olGetDeviceInfoImplDetail(ol_device_handle_t Device, ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); - InfoQueueTy DevInfo; - if (auto Err = Device->Device.obtainInfoImpl(DevInfo)) - return OL_ERRC_OUT_OF_RESOURCES; - // Find the info if it exists under any of the given names - auto GetInfo = [&DevInfo](std::vector Names) { + auto GetInfo = [&](std::vector Names) { + InfoQueueTy DevInfo; + if (auto Err = Device->Device->obtainInfoImpl(DevInfo)) + return std::string(""); + for (auto Name : Names) { auto InfoKeyMatches = [&](const InfoQueueTy::InfoQueueEntryTy &Info) { return Info.Key == Name; @@ -245,3 +277,256 @@ ol_impl_result_t olGetDeviceInfoSize_impl(ol_device_handle_t Device, size_t *PropSizeRet) { return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet); } + +ol_impl_result_t olIterateDevices_impl(ol_device_iterate_cb_t Callback, + void *UserData) { + for (auto &Platform : Platforms()) { + for (auto &Device : Platform.Devices) { + if (!Callback(&Device, UserData)) { + break; + } + } + } + + return OL_SUCCESS; +} + +TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) { + switch (Type) { + case OL_ALLOC_TYPE_DEVICE: + return TARGET_ALLOC_DEVICE; + case OL_ALLOC_TYPE_HOST: + return TARGET_ALLOC_HOST; + case OL_ALLOC_TYPE_MANAGED: + default: + return TARGET_ALLOC_SHARED; + } +} + +ol_impl_result_t olMemAlloc_impl(ol_device_handle_t Device, + ol_alloc_type_t Type, size_t Size, + void **AllocationOut) { + auto Alloc = + Device->Device->dataAlloc(Size, nullptr, convertOlToPluginAllocTy(Type)); + if (!Alloc) + return {OL_ERRC_OUT_OF_RESOURCES, + formatv("Could not create allocation on device {0}", Device).str()}; + + *AllocationOut = *Alloc; + allocInfoMap().insert_or_assign(*Alloc, AllocInfo{Device, Type}); + return OL_SUCCESS; +} + +ol_impl_result_t olMemFree_impl(void *Address) { + if 
(!allocInfoMap().contains(Address)) + return {OL_ERRC_INVALID_ARGUMENT, "Address is not a known allocation"}; + + auto AllocInfo = allocInfoMap().at(Address); + auto Device = AllocInfo.Device; + auto Type = AllocInfo.Type; + + auto Res = + Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)); + if (Res) + return {OL_ERRC_OUT_OF_RESOURCES, "Could not free allocation"}; + + allocInfoMap().erase(Address); + + return OL_SUCCESS; +} + +ol_impl_result_t olCreateQueue_impl(ol_device_handle_t Device, + ol_queue_handle_t *Queue) { + auto CreatedQueue = std::make_unique(nullptr, Device); + auto Err = Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo)); + if (Err) + return {OL_ERRC_UNKNOWN, "Could not initialize stream resource"}; + + *Queue = CreatedQueue.release(); + return OL_SUCCESS; +} + +ol_impl_result_t olDestroyQueue_impl(ol_queue_handle_t Queue) { + return olDestroy(Queue); +} + +ol_impl_result_t olWaitQueue_impl(ol_queue_handle_t Queue) { + // Host plugin doesn't have a queue set so it's not safe to call synchronize + // on it, but we have nothing to synchronize in that situation anyway. + if (Queue->AsyncInfo->Queue) { + auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo); + if (Err) + return {OL_ERRC_INVALID_QUEUE, "The queue failed to synchronize"}; + } + + // Recreate the stream resource so the queue can be reused + // TODO: Would be easier for the synchronization to (optionally) not release + // it to begin with. 
+ auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo); + if (Res) + return {OL_ERRC_UNKNOWN, "Could not reinitialize the stream resource"}; + + return OL_SUCCESS; +} + +ol_impl_result_t olWaitEvent_impl(ol_event_handle_t Event) { + auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo); + if (Res) + return {OL_ERRC_INVALID_EVENT, "The event failed to synchronize"}; + + return OL_SUCCESS; +} + +ol_impl_result_t olDestroyEvent_impl(ol_event_handle_t Event) { + return olDestroy(Event); +} + +ol_event_handle_t makeEvent(ol_queue_handle_t Queue) { + auto EventImpl = std::make_unique(nullptr, Queue); + auto Res = Queue->Device->Device->createEvent(&EventImpl->EventInfo); + if (Res) + return nullptr; + + Res = Queue->Device->Device->recordEvent(EventImpl->EventInfo, + Queue->AsyncInfo); + if (Res) + return nullptr; + + return EventImpl.release(); +} + +ol_impl_result_t olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr, + ol_device_handle_t DstDevice, void *SrcPtr, + ol_device_handle_t SrcDevice, size_t Size, + ol_event_handle_t *EventOut) { + if (DstDevice == HostDevice() && SrcDevice == HostDevice()) { + if (!Queue) { + std::memcpy(DstPtr, SrcPtr, Size); + return OL_SUCCESS; + } else { + return {OL_ERRC_INVALID_ARGUMENT, + "One of DstDevice and SrcDevice must be a non-host device if " + "Queue is specified"}; + } + } + + // If no queue is given the memcpy will be synchronous + auto QueueImpl = Queue ? 
Queue->AsyncInfo : nullptr; + + if (DstDevice == HostDevice()) { + auto Res = SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl); + if (Res) + return {OL_ERRC_UNKNOWN, "The data retrieve operation failed"}; + } else if (SrcDevice == HostDevice()) { + auto Res = DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl); + if (Res) + return {OL_ERRC_UNKNOWN, "The data submit operation failed"}; + } else { + auto Res = SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device, + DstPtr, Size, QueueImpl); + if (Res) + return {OL_ERRC_UNKNOWN, "The data exchange operation failed"}; + } + + if (EventOut) + *EventOut = makeEvent(Queue); + + return OL_SUCCESS; +} + +ol_impl_result_t olCreateProgram_impl(ol_device_handle_t Device, + const void *ProgData, size_t ProgDataSize, + ol_program_handle_t *Program) { + // Make a copy of the program binary in case it is released by the caller. + auto ImageData = MemoryBuffer::getMemBufferCopy( + StringRef(reinterpret_cast(ProgData), ProgDataSize)); + + auto DeviceImage = __tgt_device_image{ + const_cast(ImageData->getBuffer().data()), + const_cast(ImageData->getBuffer().data()) + ProgDataSize, nullptr, + nullptr}; + + ol_program_handle_t Prog = + new ol_program_impl_t(nullptr, std::move(ImageData), DeviceImage); + + auto Res = + Device->Device->loadBinary(Device->Device->Plugin, &Prog->DeviceImage); + if (!Res) { + delete Prog; + return OL_ERRC_INVALID_VALUE; + } + + Prog->Image = *Res; + *Program = Prog; + + return OL_SUCCESS; +} + +ol_impl_result_t olDestroyProgram_impl(ol_program_handle_t Program) { + return olDestroy(Program); +} + +ol_impl_result_t olGetKernel_impl(ol_program_handle_t Program, + const char *KernelName, + ol_kernel_handle_t *Kernel) { + + auto &Device = Program->Image->getDevice(); + auto KernelImpl = Device.constructKernel(KernelName); + if (!KernelImpl) + return OL_ERRC_INVALID_KERNEL_NAME; + + auto Err = KernelImpl->init(Device, *Program->Image); + if (Err) + return {OL_ERRC_UNKNOWN, "Could 
not initialize the kernel"}; + + *Kernel = &*KernelImpl; + + return OL_SUCCESS; +} + +ol_impl_result_t +olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device, + ol_kernel_handle_t Kernel, const void *ArgumentsData, + size_t ArgumentsSize, + const ol_kernel_launch_size_args_t *LaunchSizeArgs, + ol_event_handle_t *EventOut) { + auto *DeviceImpl = Device->Device; + if (Queue && Device != Queue->Device) { + return {OL_ERRC_INVALID_DEVICE, + "Device specified does not match the device of the given queue"}; + } + + auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr; + AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl); + KernelArgsTy LaunchArgs{}; + LaunchArgs.NumTeams[0] = LaunchSizeArgs->NumGroupsX; + LaunchArgs.NumTeams[1] = LaunchSizeArgs->NumGroupsY; + LaunchArgs.NumTeams[2] = LaunchSizeArgs->NumGroupsZ; + LaunchArgs.ThreadLimit[0] = LaunchSizeArgs->GroupSizeX; + LaunchArgs.ThreadLimit[1] = LaunchSizeArgs->GroupSizeY; + LaunchArgs.ThreadLimit[2] = LaunchSizeArgs->GroupSizeZ; + LaunchArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory; + + KernelLaunchParamsTy Params; + Params.Data = const_cast(ArgumentsData); + Params.Size = ArgumentsSize; + LaunchArgs.ArgPtrs = reinterpret_cast(&Params); + // Don't do anything with pointer indirection; use arg data as-is + LaunchArgs.Flags.IsCUDA = true; + + auto *KernelImpl = reinterpret_cast(Kernel); + auto Err = KernelImpl->launch(*DeviceImpl, LaunchArgs.ArgPtrs, nullptr, + LaunchArgs, AsyncInfoWrapper); + + AsyncInfoWrapper.finalize(Err); + if (Err) + return {OL_ERRC_UNKNOWN, "Could not finalize the AsyncInfoWrapper"}; + + if (EventOut) + *EventOut = makeEvent(Queue); + + return OL_SUCCESS; +} + +} // namespace offload +} // namespace llvm diff --git a/offload/liboffload/src/OffloadLib.cpp b/offload/liboffload/src/OffloadLib.cpp index 70e1ce1f84d8..8662d3a44124 100644 --- a/offload/liboffload/src/OffloadLib.cpp +++ b/offload/liboffload/src/OffloadLib.cpp @@ -11,11 +11,10 @@ 
//===----------------------------------------------------------------------===// #include "OffloadImpl.hpp" +#include "llvm/Support/raw_ostream.h" #include #include -#include - llvm::StringSet<> &errorStrs() { static llvm::StringSet<> ErrorStrs; return ErrorStrs; @@ -36,9 +35,13 @@ OffloadConfig &offloadConfig() { return Config; } +namespace llvm { +namespace offload { // Pull in the declarations for the implementation functions. The actual entry // points in this file wrap these. #include "OffloadImplFuncDecls.inc" +} // namespace offload +} // namespace llvm // Pull in the tablegen'd entry point definitions. #include "OffloadEntryPoints.inc" diff --git a/offload/test/tools/offload-tblgen/entry_points.td b/offload/test/tools/offload-tblgen/entry_points.td index a66ddb927992..cfddb84aa5b0 100644 --- a/offload/test/tools/offload-tblgen/entry_points.td +++ b/offload/test/tools/offload-tblgen/entry_points.td @@ -20,7 +20,7 @@ def : Function { // The validation function should call the implementation function // CHECK: FunctionA_val -// CHECK: return FunctionA_impl(ParamA, ParamB); +// CHECK: return llvm::offload::FunctionA_impl(ParamA, ParamB); // CHECK: ol_result_t{{.*}} FunctionA( diff --git a/offload/test/tools/offload-tblgen/functions_ranged_param.td b/offload/test/tools/offload-tblgen/functions_ranged_param.td index 21a84d8a7033..d0996b231973 100644 --- a/offload/test/tools/offload-tblgen/functions_ranged_param.td +++ b/offload/test/tools/offload-tblgen/functions_ranged_param.td @@ -25,7 +25,7 @@ def : Function { let returns = []; } -// CHECK: inline std::ostream &operator<<(std::ostream &os, const struct function_a_params_t *params) { +// CHECK: inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct function_a_params_t *params) { // CHECK: os << ".OutPtr = "; // CHECK: for (size_t i = 0; i < *params->pOutCount; i++) { // CHECK: if (i > 0) { diff --git a/offload/test/tools/offload-tblgen/print_enum.td 
b/offload/test/tools/offload-tblgen/print_enum.td index 0b5506009bec..97f869689293 100644 --- a/offload/test/tools/offload-tblgen/print_enum.td +++ b/offload/test/tools/offload-tblgen/print_enum.td @@ -15,7 +15,7 @@ def : Enum { ]; } -// CHECK: inline std::ostream &operator<<(std::ostream &os, enum my_enum_t value) +// CHECK: inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, enum my_enum_t value) // CHECK: switch (value) { // CHECK: case MY_ENUM_VALUE_ONE: // CHECK: os << "MY_ENUM_VALUE_ONE"; diff --git a/offload/test/tools/offload-tblgen/print_function.td b/offload/test/tools/offload-tblgen/print_function.td index 3f4944df6594..ce1fe4c52760 100644 --- a/offload/test/tools/offload-tblgen/print_function.td +++ b/offload/test/tools/offload-tblgen/print_function.td @@ -27,7 +27,7 @@ def : Function { // CHECK-API-NEXT: ol_foo_handle_t* pParamHandle; // CHECK-API-NEXT: uint32_t** pParamPointer; -// CHECK-PRINT: inline std::ostream &operator<<(std::ostream &os, const struct function_a_params_t *params) +// CHECK-PRINT: inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct function_a_params_t *params) // CHECK-PRINT: os << ".ParamValue = "; // CHECK-PRINT: os << *params->pParamValue; // CHECK-PRINT: os << ", "; diff --git a/offload/test/tools/offload-tblgen/type_tagged_enum.td b/offload/test/tools/offload-tblgen/type_tagged_enum.td index 49e91e43bb6e..95964e32f0c9 100644 --- a/offload/test/tools/offload-tblgen/type_tagged_enum.td +++ b/offload/test/tools/offload-tblgen/type_tagged_enum.td @@ -50,7 +50,7 @@ def : Function { } // Check that a tagged enum print function definition is generated -// CHECK-PRINT: void printTagged(std::ostream &os, const void *ptr, my_type_tagged_enum_t value, size_t size) { +// CHECK-PRINT: void printTagged(llvm::raw_ostream &os, const void *ptr, my_type_tagged_enum_t value, size_t size) { // CHECK-PRINT: case MY_TYPE_TAGGED_ENUM_VALUE_ONE: { // CHECK-PRINT: const uint32_t * const tptr = (const uint32_t * const)ptr; 
// CHECK-PRINT: os << (const void *)tptr << " ("; @@ -71,6 +71,6 @@ def : Function { // CHECK-PRINT: } // Check that the tagged type information is used when printing function parameters -// CHECK-PRINT: std::ostream &operator<<(std::ostream &os, const struct function_a_params_t *params) { +// CHECK-PRINT: llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct function_a_params_t *params) { // CHECK-PRINT: os << ".PropValue = " // CHECK-PRINT-NEXT: printTagged(os, *params->pPropValue, *params->pPropName, *params->pPropSize); diff --git a/offload/tools/offload-tblgen/APIGen.cpp b/offload/tools/offload-tblgen/APIGen.cpp index 97a2464f7a75..800c9cadfe38 100644 --- a/offload/tools/offload-tblgen/APIGen.cpp +++ b/offload/tools/offload-tblgen/APIGen.cpp @@ -41,9 +41,16 @@ static std::string MakeComment(StringRef in) { } static void ProcessHandle(const HandleRec &H, raw_ostream &OS) { + if (!H.getName().ends_with("_handle_t")) { + errs() << "Handle type name (" << H.getName() + << ") must end with '_handle_t'!\n"; + exit(1); + } + + std::string ImplName = (H.getName().drop_back(9) + "_impl_t").str(); OS << CommentsHeader; OS << formatv("/// @brief {0}\n", H.getDesc()); - OS << formatv("typedef struct {0}_ *{0};\n", H.getName()); + OS << formatv("typedef struct {0} *{1};\n", ImplName, H.getName()); } static void ProcessTypedef(const TypedefRec &T, raw_ostream &OS) { @@ -158,6 +165,19 @@ static void ProcessStruct(const StructRec &Struct, raw_ostream &OS) { OS << formatv("} {0};\n\n", Struct.getName()); } +static void ProcessFptrTypedef(const FptrTypedefRec &F, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief {0}\n", F.getDesc()); + OS << formatv("typedef {0} (*{1})(", F.getReturn(), F.getName()); + for (const auto &Param : F.getParams()) { + OS << formatv("\n // {0}\n {1} {2}", Param.getDesc(), Param.getType(), + Param.getName()); + if (Param != F.getParams().back()) + OS << ","; + } + OS << ");\n"; +} + static void 
ProcessFuncParamStruct(const FunctionRec &Func, raw_ostream &OS) { if (Func.getParams().size() == 0) { return; @@ -213,6 +233,8 @@ void EmitOffloadAPI(const RecordKeeper &Records, raw_ostream &OS) { ProcessEnum(EnumRec{R}, OS); } else if (R->isSubClassOf("Struct")) { ProcessStruct(StructRec{R}, OS); + } else if (R->isSubClassOf("FptrTypedef")) { + ProcessFptrTypedef(FptrTypedefRec{R}, OS); } } diff --git a/offload/tools/offload-tblgen/EntryPointGen.cpp b/offload/tools/offload-tblgen/EntryPointGen.cpp index 990ff96a3121..66b9665292e1 100644 --- a/offload/tools/offload-tblgen/EntryPointGen.cpp +++ b/offload/tools/offload-tblgen/EntryPointGen.cpp @@ -35,7 +35,7 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) { } OS << ") {\n"; - OS << TAB_1 "if (true /*enableParameterValidation*/) {\n"; + OS << TAB_1 "if (offloadConfig().ValidationEnabled) {\n"; // Emit validation checks for (const auto &Return : F.getReturns()) { for (auto &Condition : Return.getConditions()) { @@ -51,7 +51,8 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) { // Perform actual function call to the implementation ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2); - OS << formatv(TAB_1 "return {0}_impl({1});\n\n", F.getName(), ParamNameList); + OS << formatv(TAB_1 "return llvm::offload::{0}_impl({1});\n\n", F.getName(), + ParamNameList); OS << "}\n"; } @@ -72,7 +73,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) { // Emit pre-call prints OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n"; - OS << formatv(TAB_2 "std::cout << \"---> {0}\";\n", F.getName()); + OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName()); OS << TAB_1 "}\n\n"; // Perform actual function call to the validation wrapper @@ -91,13 +92,13 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) { } } OS << formatv("};\n"); - OS << TAB_2 "std::cout << \"(\" << &Params << \")\";\n"; + OS << TAB_2 "llvm::errs() << 
\"(\" << &Params << \")\";\n"; } else { - OS << TAB_2 "std::cout << \"()\";\n"; + OS << TAB_2 "llvm::errs() << \"()\";\n"; } - OS << TAB_2 "std::cout << \"-> \" << Result << \"\\n\";\n"; + OS << TAB_2 "llvm::errs() << \"-> \" << Result << \"\\n\";\n"; OS << TAB_2 "if (Result && Result->Details) {\n"; - OS << TAB_3 "std::cout << \" *Error Details* \" << Result->Details " + OS << TAB_3 "llvm::errs() << \" *Error Details* \" << Result->Details " "<< \" \\n\";\n"; OS << TAB_2 "}\n"; OS << TAB_1 "}\n"; @@ -121,7 +122,7 @@ static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) { OS << "ol_code_location_t *CodeLocation"; OS << ") {\n"; OS << TAB_1 "currentCodeLocation() = CodeLocation;\n"; - OS << formatv(TAB_1 "{0}_result_t Result = {1}({2});\n\n", PrefixLower, + OS << formatv(TAB_1 "{0}_result_t Result = ::{1}({2});\n\n", PrefixLower, F.getName(), ParamNameList); OS << TAB_1 "currentCodeLocation() = nullptr;\n"; OS << TAB_1 "return Result;\n"; diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp index 2a7c63c3dfd1..a964ff09d0f6 100644 --- a/offload/tools/offload-tblgen/PrintGen.cpp +++ b/offload/tools/offload-tblgen/PrintGen.cpp @@ -20,24 +20,24 @@ using namespace llvm; using namespace offload::tblgen; -constexpr auto PrintEnumHeader = +constexpr auto PrintTypeHeader = R"(/////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the {0} type -/// @returns std::ostream & +/// @returns llvm::raw_ostream & )"; constexpr auto PrintTaggedEnumHeader = R"(/////////////////////////////////////////////////////////////////////////////// /// @brief Print type-tagged {0} enum value -/// @returns std::ostream & +/// @returns llvm::raw_ostream & )"; static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) { - OS << formatv(PrintEnumHeader, Enum.getName()); - OS << formatv( - "inline std::ostream &operator<<(std::ostream &os, enum {0} value) " - "{{\n" TAB_1 "switch 
(value) {{\n", - Enum.getName()); + OS << formatv(PrintTypeHeader, Enum.getName()); + OS << formatv("inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, " + "enum {0} value) " + "{{\n" TAB_1 "switch (value) {{\n", + Enum.getName()); for (const auto &Val : Enum.getValues()) { auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName(); @@ -56,7 +56,7 @@ static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) { OS << formatv(PrintTaggedEnumHeader, Enum.getName()); OS << formatv(R"""(template <> -inline void printTagged(std::ostream &os, const void *ptr, {0} value, size_t size) {{ +inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_t size) {{ if (ptr == NULL) {{ printPtr(os, ptr); return; @@ -96,7 +96,7 @@ inline void printTagged(std::ostream &os, const void *ptr, {0} value, size_t siz static void EmitResultPrint(raw_ostream &OS) { OS << R""( -inline std::ostream &operator<<(std::ostream &os, +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const ol_error_struct_t *Err) { if (Err == nullptr) { os << "OL_SUCCESS"; @@ -115,7 +115,7 @@ static void EmitFunctionParamStructPrint(const FunctionRec &Func, } OS << formatv(R"( -inline std::ostream &operator<<(std::ostream &os, const struct {0} *params) {{ +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct {0} *params) {{ )", Func.getParamStructName()); @@ -139,6 +139,9 @@ inline std::ostream &operator<<(std::ostream &os, const struct {0} *params) {{ Param.getName(), TypeInfo->first, TypeInfo->second); } else if (Param.isPointerType() || Param.isHandleType()) { OS << formatv(TAB_1 "printPtr(os, *params->p{0});\n", Param.getName()); + } else if (Param.isFptrType()) { + OS << formatv(TAB_1 "os << reinterpret_cast(*params->p{0});\n", + Param.getName()); } else { OS << formatv(TAB_1 "os << *params->p{0};\n", Param.getName()); } @@ -150,6 +153,32 @@ inline std::ostream &operator<<(std::ostream &os, const struct {0} *params) {{ OS << TAB_1 "return 
os;\n}\n"; } +void ProcessStruct(const StructRec &Struct, raw_ostream &OS) { + if (Struct.getName() == "ol_error_struct_t") { + return; + } + OS << formatv(PrintTypeHeader, Struct.getName()); + OS << formatv(R"( +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct {0} params) {{ +)", + Struct.getName()); + OS << formatv(TAB_1 "os << \"(struct {0}){{\";\n", Struct.getName()); + for (const auto &Member : Struct.getMembers()) { + OS << formatv(TAB_1 "os << \".{0} = \";\n", Member.getName()); + if (Member.isPointerType() || Member.isHandleType()) { + OS << formatv(TAB_1 "printPtr(os, params.{0});\n", Member.getName()); + } else { + OS << formatv(TAB_1 "os << params.{0};\n", Member.getName()); + } + if (Member.getName() != Struct.getMembers().back().getName()) { + OS << TAB_1 "os << \", \";\n"; + } + } + OS << TAB_1 "os << \"}\";\n"; + OS << TAB_1 "return os;\n"; + OS << "}\n"; +} + void EmitOffloadPrintHeader(const RecordKeeper &Records, raw_ostream &OS) { OS << GenericHeader; OS << R"""( @@ -158,11 +187,11 @@ void EmitOffloadPrintHeader(const RecordKeeper &Records, raw_ostream &OS) { #pragma once #include -#include +#include -template inline ol_result_t printPtr(std::ostream &os, const T *ptr); -template inline void printTagged(std::ostream &os, const void *ptr, T value, size_t size); +template inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr); +template inline void printTagged(llvm::raw_ostream &os, const void *ptr, T value, size_t size); )"""; // ========== @@ -180,9 +209,9 @@ template inline void printTagged(std::ostream &os, const void *ptr, // use each other. 
OS << "\n"; for (auto *R : Records.getAllDerivedDefinitions("Enum")) { - OS << formatv( - "inline std::ostream &operator<<(std::ostream &os, enum {0} value);\n", - EnumRec{R}.getName()); + OS << formatv("inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, " + "enum {0} value);\n", + EnumRec{R}.getName()); } OS << "\n"; @@ -193,6 +222,11 @@ template inline void printTagged(std::ostream &os, const void *ptr, } EmitResultPrint(OS); + for (auto *R : Records.getAllDerivedDefinitions("Struct")) { + StructRec S{R}; + ProcessStruct(S, OS); + } + // Emit print functions for the function param structs for (auto *R : Records.getAllDerivedDefinitions("Function")) { EmitFunctionParamStructPrint(FunctionRec{R}, OS); @@ -201,7 +235,7 @@ template inline void printTagged(std::ostream &os, const void *ptr, OS << R"""( /////////////////////////////////////////////////////////////////////////////// // @brief Print pointer value -template inline ol_result_t printPtr(std::ostream &os, const T *ptr) { +template inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr) { if (ptr == nullptr) { os << "nullptr"; } else if constexpr (std::is_pointer_v) { diff --git a/offload/tools/offload-tblgen/RecordTypes.hpp b/offload/tools/offload-tblgen/RecordTypes.hpp index 0bf3256c525d..686634ed778a 100644 --- a/offload/tools/offload-tblgen/RecordTypes.hpp +++ b/offload/tools/offload-tblgen/RecordTypes.hpp @@ -103,6 +103,8 @@ public: StringRef getType() const { return rec->getValueAsString("type"); } StringRef getName() const { return rec->getValueAsString("name"); } StringRef getDesc() const { return rec->getValueAsString("desc"); } + bool isPointerType() const { return getType().ends_with('*'); } + bool isHandleType() const { return getType().ends_with("_handle_t"); } private: const Record *rec; @@ -153,6 +155,7 @@ public: StringRef getType() const { return rec->getValueAsString("type"); } bool isPointerType() const { return getType().ends_with('*'); } bool isHandleType() const { 
return getType().ends_with("_handle_t"); } + bool isFptrType() const { return getType().ends_with("_cb_t"); } StringRef getDesc() const { return rec->getValueAsString("desc"); } bool isIn() const { return dyn_cast(flags->getBit(0))->getValue(); } bool isOut() const { return dyn_cast(flags->getBit(1))->getValue(); } @@ -222,6 +225,23 @@ private: const Record *rec; }; +class FptrTypedefRec { +public: + explicit FptrTypedefRec(const Record *rec) : rec(rec) { + for (auto &Param : rec->getValueAsListOfDefs("params")) + params.emplace_back(Param); + } + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + StringRef getReturn() const { return rec->getValueAsString("return"); } + const std::vector &getParams() const { return params; } + +private: + std::vector params; + + const Record *rec; +}; + } // namespace tblgen } // namespace offload } // namespace llvm diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt index 033ee2b6ec74..c4d628a5a87f 100644 --- a/offload/unittests/OffloadAPI/CMakeLists.txt +++ b/offload/unittests/OffloadAPI/CMakeLists.txt @@ -1,16 +1,28 @@ set(PLUGINS_TEST_COMMON LLVMOffload) set(PLUGINS_TEST_INCLUDE ${LIBOMPTARGET_INCLUDE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/common) +add_subdirectory(device_code) +message(STATUS "Offload unittests device code path: ${OFFLOAD_TEST_DEVICE_CODE_PATH}") + add_libompt_unittest("offload.unittests" ${CMAKE_CURRENT_SOURCE_DIR}/common/Environment.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/platform/olGetPlatform.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/platform/olGetPlatformCount.cpp ${CMAKE_CURRENT_SOURCE_DIR}/platform/olGetPlatformInfo.cpp ${CMAKE_CURRENT_SOURCE_DIR}/platform/olGetPlatformInfoSize.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/device/olGetDevice.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/device/olGetDeviceCount.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device/olIterateDevices.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device/olGetDeviceInfo.cpp - 
${CMAKE_CURRENT_SOURCE_DIR}/device/olGetDeviceInfoSize.cpp) -add_dependencies("offload.unittests" ${PLUGINS_TEST_COMMON}) + ${CMAKE_CURRENT_SOURCE_DIR}/device/olGetDeviceInfoSize.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/queue/olCreateQueue.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/queue/olWaitQueue.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/queue/olDestroyQueue.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/memory/olMemAlloc.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/memory/olMemFree.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/memory/olMemcpy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/program/olCreateProgram.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/program/olDestroyProgram.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/olGetKernel.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/olLaunchKernel.cpp + ) +add_dependencies("offload.unittests" ${PLUGINS_TEST_COMMON} LibomptUnitTestsDeviceBins) +target_compile_definitions("offload.unittests" PRIVATE DEVICE_CODE_PATH="${OFFLOAD_TEST_DEVICE_CODE_PATH}") target_link_libraries("offload.unittests" PRIVATE ${PLUGINS_TEST_COMMON}) target_include_directories("offload.unittests" PRIVATE ${PLUGINS_TEST_INCLUDE}) diff --git a/offload/unittests/OffloadAPI/common/Environment.cpp b/offload/unittests/OffloadAPI/common/Environment.cpp index f07a66cda218..88cf33e45f3d 100644 --- a/offload/unittests/OffloadAPI/common/Environment.cpp +++ b/offload/unittests/OffloadAPI/common/Environment.cpp @@ -9,7 +9,9 @@ #include "Environment.hpp" #include "Fixtures.hpp" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/MemoryBuffer.h" #include +#include using namespace llvm; @@ -25,8 +27,8 @@ static cl::opt SelectedPlatform("platform", cl::desc("Only test the specified platform"), cl::value_desc("platform")); -std::ostream &operator<<(std::ostream &Out, - const ol_platform_handle_t &Platform) { +raw_ostream &operator<<(raw_ostream &Out, + const ol_platform_handle_t &Platform) { size_t Size; olGetPlatformInfoSize(Platform, OL_PLATFORM_INFO_NAME, &Size); std::vector Name(Size); @@ -35,62 +37,132 @@ std::ostream 
&operator<<(std::ostream &Out, return Out; } -std::ostream &operator<<(std::ostream &Out, - const std::vector &Platforms) { - for (auto Platform : Platforms) { - Out << "\n * \"" << Platform << "\""; - } - return Out; -} +void printPlatforms() { + SmallDenseSet Platforms; + using DeviceVecT = SmallVector; + DeviceVecT Devices{}; -const std::vector &TestEnvironment::getPlatforms() { - static std::vector Platforms{}; + olIterateDevices( + [](ol_device_handle_t D, void *Data) { + static_cast(Data)->push_back(D); + return true; + }, + &Devices); - if (Platforms.empty()) { - uint32_t PlatformCount = 0; - olGetPlatformCount(&PlatformCount); - if (PlatformCount > 0) { - Platforms.resize(PlatformCount); - olGetPlatform(PlatformCount, Platforms.data()); - } + for (auto &Device : Devices) { + ol_platform_handle_t Platform; + olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM, sizeof(Platform), + &Platform); + Platforms.insert(Platform); } - return Platforms; + for (const auto &Platform : Platforms) { + errs() << " * " << Platform << "\n"; + } } -// Get a single platform, which may be selected by the user. 
-ol_platform_handle_t TestEnvironment::getPlatform() { - static ol_platform_handle_t Platform = nullptr; - const auto &Platforms = getPlatforms(); +ol_device_handle_t TestEnvironment::getDevice() { + static ol_device_handle_t Device = nullptr; - if (!Platform) { + if (!Device) { if (SelectedPlatform != "") { - for (const auto CandidatePlatform : Platforms) { - std::stringstream PlatformName; - PlatformName << CandidatePlatform; - if (SelectedPlatform == PlatformName.str()) { - Platform = CandidatePlatform; - return Platform; - } + olIterateDevices( + [](ol_device_handle_t D, void *Data) { + ol_platform_handle_t Platform; + olGetDeviceInfo(D, OL_DEVICE_INFO_PLATFORM, sizeof(Platform), + &Platform); + + std::string PlatformName; + raw_string_ostream S(PlatformName); + S << Platform; + + if (PlatformName == SelectedPlatform) { + *(static_cast(Data)) = D; + return false; + } + + return true; + }, + &Device); + + if (Device == nullptr) { + errs() << "No device found with the platform \"" << SelectedPlatform + << "\". Choose from:" + << "\n"; + printPlatforms(); + std::exit(1); } - std::cout << "No platform found with the name \"" << SelectedPlatform - << "\". Choose from:" << Platforms << "\n"; - std::exit(1); } else { - // Pick a single platform. We prefer one that has available devices, but - // just pick the first initially in case none have any devices. 
- Platform = Platforms[0]; - for (auto CandidatePlatform : Platforms) { - uint32_t NumDevices = 0; - if (olGetDeviceCount(CandidatePlatform, &NumDevices) == OL_SUCCESS) { - if (NumDevices > 0) { - Platform = CandidatePlatform; - break; - } - } - } + olIterateDevices( + [](ol_device_handle_t D, void *Data) { + *(static_cast(Data)) = D; + return false; + }, + &Device); } } - return Platform; + return Device; +} + +ol_device_handle_t TestEnvironment::getHostDevice() { + static ol_device_handle_t HostDevice = nullptr; + + if (!HostDevice) { + olIterateDevices( + [](ol_device_handle_t D, void *Data) { + ol_platform_handle_t Platform; + olGetDeviceInfo(D, OL_DEVICE_INFO_PLATFORM, sizeof(Platform), + &Platform); + ol_platform_backend_t Backend; + olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND, sizeof(Backend), + &Backend); + + if (Backend == OL_PLATFORM_BACKEND_HOST) { + *(static_cast(Data)) = D; + return false; + } + + return true; + }, + &HostDevice); + } + + return HostDevice; +} + +// TODO: Allow overriding via cmd line arg +const std::string DeviceBinsDirectory = DEVICE_CODE_PATH; + +bool TestEnvironment::loadDeviceBinary( + const std::string &BinaryName, ol_device_handle_t Device, + std::unique_ptr &BinaryOut) { + + // Get the platform type + ol_platform_handle_t Platform; + olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM, sizeof(Platform), &Platform); + ol_platform_backend_t Backend = OL_PLATFORM_BACKEND_UNKNOWN; + olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND, sizeof(Backend), + &Backend); + std::string FileExtension; + if (Backend == OL_PLATFORM_BACKEND_AMDGPU) { + FileExtension = ".amdgpu.bin"; + } else if (Backend == OL_PLATFORM_BACKEND_CUDA) { + FileExtension = ".nvptx64.bin"; + } else { + errs() << "Unsupported platform type for a device binary test.\n"; + return false; + } + + std::string SourcePath = + DeviceBinsDirectory + "/" + BinaryName + FileExtension; + + auto SourceFile = MemoryBuffer::getFile(SourcePath, false, false); + if 
(!SourceFile) { + errs() << "failed to read device binary file: " << SourcePath << "\n"; + return false; + } + + BinaryOut = std::move(SourceFile.get()); + return true; } diff --git a/offload/unittests/OffloadAPI/common/Environment.hpp b/offload/unittests/OffloadAPI/common/Environment.hpp index 6dba2381eb0b..a0bf688b4551 100644 --- a/offload/unittests/OffloadAPI/common/Environment.hpp +++ b/offload/unittests/OffloadAPI/common/Environment.hpp @@ -8,10 +8,13 @@ #pragma once +#include "llvm/Support/MemoryBuffer.h" #include #include namespace TestEnvironment { -const std::vector &getPlatforms(); -ol_platform_handle_t getPlatform(); +ol_device_handle_t getDevice(); +ol_device_handle_t getHostDevice(); +bool loadDeviceBinary(const std::string &BinaryName, ol_device_handle_t Device, + std::unique_ptr &BinaryOut); } // namespace TestEnvironment diff --git a/offload/unittests/OffloadAPI/common/Fixtures.hpp b/offload/unittests/OffloadAPI/common/Fixtures.hpp index 410a435dee1b..028ebf43d5cd 100644 --- a/offload/unittests/OffloadAPI/common/Fixtures.hpp +++ b/offload/unittests/OffloadAPI/common/Fixtures.hpp @@ -27,6 +27,14 @@ } while (0) #endif +#ifndef ASSERT_ANY_ERROR +#define ASSERT_ANY_ERROR(ACTUAL) \ + do { \ + ol_result_t Res = ACTUAL; \ + ASSERT_TRUE(Res); \ + } while (0) +#endif + #define RETURN_ON_FATAL_FAILURE(...) 
\ __VA_ARGS__; \ if (this->HasFatalFailure() || this->IsSkipped()) { \ @@ -34,31 +42,81 @@ } \ (void)0 -struct offloadTest : ::testing::Test { - // No special behavior now, but just in case we need to override it in future +struct OffloadTest : ::testing::Test { + ol_device_handle_t Host = TestEnvironment::getHostDevice(); }; -struct offloadPlatformTest : offloadTest { +struct OffloadDeviceTest : OffloadTest { void SetUp() override { - RETURN_ON_FATAL_FAILURE(offloadTest::SetUp()); + RETURN_ON_FATAL_FAILURE(OffloadTest::SetUp()); - Platform = TestEnvironment::getPlatform(); + Device = TestEnvironment::getDevice(); + if (Device == nullptr) + GTEST_SKIP() << "No available devices."; + } + + ol_device_handle_t Device = nullptr; +}; + +struct OffloadPlatformTest : OffloadDeviceTest { + void SetUp() override { + RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp()); + + ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM, + sizeof(Platform), &Platform)); ASSERT_NE(Platform, nullptr); } - ol_platform_handle_t Platform; + ol_platform_handle_t Platform = nullptr; }; -struct offloadDeviceTest : offloadPlatformTest { +// Fixture for a generic program test. If you want a different program, use +// OffloadQueueTest and create your own program handle with the binary you want. 
+struct OffloadProgramTest : OffloadDeviceTest { void SetUp() override { - RETURN_ON_FATAL_FAILURE(offloadPlatformTest::SetUp()); - - uint32_t NumDevices; - ASSERT_SUCCESS(olGetDeviceCount(Platform, &NumDevices)); - if (NumDevices == 0) - GTEST_SKIP() << "No available devices on this platform."; - ASSERT_SUCCESS(olGetDevice(Platform, 1, &Device)); + RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp()); + ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin)); + ASSERT_GT(DeviceBin->getBufferSize(), 0lu); + ASSERT_SUCCESS(olCreateProgram(Device, DeviceBin->getBufferStart(), + DeviceBin->getBufferSize(), &Program)); } - ol_device_handle_t Device; + void TearDown() override { + if (Program) { + olDestroyProgram(Program); + } + RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::TearDown()); + } + + ol_program_handle_t Program = nullptr; + std::unique_ptr DeviceBin; +}; + +struct OffloadKernelTest : OffloadProgramTest { + void SetUp() override { + RETURN_ON_FATAL_FAILURE(OffloadProgramTest::SetUp()); + ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel)); + } + + void TearDown() override { + RETURN_ON_FATAL_FAILURE(OffloadProgramTest::TearDown()); + } + + ol_kernel_handle_t Kernel = nullptr; +}; + +struct OffloadQueueTest : OffloadDeviceTest { + void SetUp() override { + RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp()); + ASSERT_SUCCESS(olCreateQueue(Device, &Queue)); + } + + void TearDown() override { + if (Queue) { + olDestroyQueue(Queue); + } + RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::TearDown()); + } + + ol_queue_handle_t Queue = nullptr; }; diff --git a/offload/unittests/OffloadAPI/device/olGetDevice.cpp b/offload/unittests/OffloadAPI/device/olGetDevice.cpp deleted file mode 100644 index 68d4682dd335..000000000000 --- a/offload/unittests/OffloadAPI/device/olGetDevice.cpp +++ /dev/null @@ -1,39 +0,0 @@ -//===------- Offload API tests - olGetDevice -------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "../common/Fixtures.hpp" -#include -#include - -using olGetDeviceTest = offloadPlatformTest; - -TEST_F(olGetDeviceTest, Success) { - uint32_t Count = 0; - ASSERT_SUCCESS(olGetDeviceCount(Platform, &Count)); - if (Count == 0) - GTEST_SKIP() << "No available devices on this platform."; - - std::vector Devices(Count); - ASSERT_SUCCESS(olGetDevice(Platform, Count, Devices.data())); - for (auto Device : Devices) { - ASSERT_NE(nullptr, Device); - } -} - -TEST_F(olGetDeviceTest, SuccessSubsetOfDevices) { - uint32_t Count; - ASSERT_SUCCESS(olGetDeviceCount(Platform, &Count)); - if (Count < 2) - GTEST_SKIP() << "Only one device is available on this platform."; - - std::vector Devices(Count - 1); - ASSERT_SUCCESS(olGetDevice(Platform, Count - 1, Devices.data())); - for (auto Device : Devices) { - ASSERT_NE(nullptr, Device); - } -} diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceCount.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceCount.cpp deleted file mode 100644 index ef377d671bf6..000000000000 --- a/offload/unittests/OffloadAPI/device/olGetDeviceCount.cpp +++ /dev/null @@ -1,28 +0,0 @@ -//===------- Offload API tests - olGetDeviceCount --------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "../common/Fixtures.hpp" -#include -#include - -using olGetDeviceCountTest = offloadPlatformTest; - -TEST_F(olGetDeviceCountTest, Success) { - uint32_t Count = 0; - ASSERT_SUCCESS(olGetDeviceCount(Platform, &Count)); -} - -TEST_F(olGetDeviceCountTest, InvalidNullPlatform) { - uint32_t Count = 0; - ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olGetDeviceCount(nullptr, &Count)); -} - -TEST_F(olGetDeviceCountTest, InvalidNullPointer) { - ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, - olGetDeviceCount(Platform, nullptr)); -} diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp index c936802fb1e4..f71f60a2c057 100644 --- a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp +++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp @@ -11,10 +11,10 @@ #include #include -struct olGetDeviceInfoTest : offloadDeviceTest, +struct olGetDeviceInfoTest : OffloadDeviceTest, ::testing::WithParamInterface { - void SetUp() override { RETURN_ON_FATAL_FAILURE(offloadDeviceTest::SetUp()); } + void SetUp() override { RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp()); } }; INSTANTIATE_TEST_SUITE_P( @@ -37,7 +37,7 @@ TEST_P(olGetDeviceInfoTest, Success) { if (InfoType == OL_DEVICE_INFO_PLATFORM) { auto *ReturnedPlatform = reinterpret_cast(InfoData.data()); - ASSERT_EQ(Platform, *ReturnedPlatform); + ASSERT_NE(nullptr, *ReturnedPlatform); } } diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp index 9e792d1c3e25..b4b5042dbfd8 100644 --- a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp +++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp @@ -12,10 +12,10 @@ #include "olDeviceInfo.hpp" struct olGetDeviceInfoSizeTest - : offloadDeviceTest, + : 
OffloadDeviceTest, ::testing::WithParamInterface { - void SetUp() override { RETURN_ON_FATAL_FAILURE(offloadDeviceTest::SetUp()); } + void SetUp() override { RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp()); } }; // TODO: We could autogenerate the list of enum values diff --git a/offload/unittests/OffloadAPI/device/olIterateDevices.cpp b/offload/unittests/OffloadAPI/device/olIterateDevices.cpp new file mode 100644 index 000000000000..5bdbd17e9e97 --- /dev/null +++ b/offload/unittests/OffloadAPI/device/olIterateDevices.cpp @@ -0,0 +1,45 @@ +//===------- Offload API tests - olIterateDevices -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +using olIterateDevicesTest = OffloadTest; + +TEST_F(olIterateDevicesTest, SuccessEmptyCallback) { + ASSERT_SUCCESS(olIterateDevices( + [](ol_device_handle_t, void *) { return false; }, nullptr)); +} + +TEST_F(olIterateDevicesTest, SuccessGetDevice) { + uint32_t DeviceCount = 0; + ol_device_handle_t Device = nullptr; + + ASSERT_SUCCESS(olIterateDevices( + [](ol_device_handle_t, void *Data) { + auto Count = static_cast(Data); + *Count += 1; + return false; + }, + &DeviceCount)); + + if (DeviceCount == 0) { + GTEST_SKIP() << "No available devices."; + } + + ASSERT_SUCCESS(olIterateDevices( + [](ol_device_handle_t D, void *Data) { + auto DevicePtr = static_cast(Data); + *DevicePtr = D; + return true; + }, + &Device)); + + ASSERT_NE(Device, nullptr); +} diff --git a/offload/unittests/OffloadAPI/device_code/CMakeLists.txt b/offload/unittests/OffloadAPI/device_code/CMakeLists.txt new file mode 100644 index 000000000000..ded555b3a3cf --- /dev/null +++ b/offload/unittests/OffloadAPI/device_code/CMakeLists.txt 
@@ -0,0 +1,67 @@ +macro(add_offload_test_device_code test_filename test_name) + set(SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/${test_filename}) + + # Build for NVPTX + if(OFFLOAD_TEST_TARGET_NVIDIA) + set(BIN_PATH ${CMAKE_CURRENT_BINARY_DIR}/${test_name}.nvptx64.bin) + add_custom_command(OUTPUT ${BIN_PATH} + COMMAND + ${CMAKE_C_COMPILER} --target=nvptx64-nvidia-cuda + -march=${LIBOMPTARGET_DEP_CUDA_ARCH} + --cuda-path=${CUDA_ROOT} + ${SRC_PATH} -o ${BIN_PATH} + DEPENDS ${SRC_PATH} + ) + list(APPEND BIN_PATHS ${BIN_PATH}) + endif() + + # Build for AMDGPU + if(OFFLOAD_TEST_TARGET_AMDGPU) + set(BIN_PATH ${CMAKE_CURRENT_BINARY_DIR}/${test_name}.amdgpu.bin) + add_custom_command(OUTPUT ${BIN_PATH} + COMMAND + ${CMAKE_C_COMPILER} --target=amdgcn-amd-amdhsa -nogpulib + -mcpu=${LIBOMPTARGET_DEP_AMDGPU_ARCH} + ${SRC_PATH} -o ${BIN_PATH} + DEPENDS ${SRC_PATH} + ) + list(APPEND BIN_PATHS ${BIN_PATH}) + endif() + + # TODO: Build for host CPU +endmacro() + + +# Decide what device targets to build for. LibomptargetGetDependencies is +# included at the top-level so the GPUs present on the system are already +# detected. +set(OFFLOAD_TESTS_FORCE_NVIDIA_ARCH "" CACHE STRING + "Force building of NVPTX device code for Offload unit tests with the given arch, e.g. sm_61") +set(OFFLOAD_TESTS_FORCE_AMDGPU_ARCH "" CACHE STRING + "Force building of AMDGPU device code for Offload unit tests with the given arch, e.g. 
gfx1030") + +find_package(CUDAToolkit QUIET) +if(CUDAToolkit_FOUND) + get_filename_component(CUDA_ROOT "${CUDAToolkit_BIN_DIR}" DIRECTORY ABSOLUTE) +endif() +if (OFFLOAD_TESTS_FORCE_NVIDIA_ARCH) + set(LIBOMPTARGET_DEP_CUDA_ARCH ${OFFLOAD_TESTS_FORCE_NVIDIA_ARCH}) + set(OFFLOAD_TEST_TARGET_NVIDIA ON) +elseif (LIBOMPTARGET_FOUND_NVIDIA_GPU AND CUDA_ROOT AND "cuda" IN_LIST LIBOMPTARGET_PLUGINS_TO_BUILD) + set(OFFLOAD_TEST_TARGET_NVIDIA ON) +endif() + +if (OFFLOAD_TESTS_FORCE_AMDGPU_ARCH) + set(LIBOMPTARGET_DEP_AMDGPU_ARCH ${OFFLOAD_TESTS_FORCE_AMDGPU_ARCH}) + set(OFFLOAD_TEST_TARGET_AMDGPU ON) +elseif (LIBOMPTARGET_FOUND_AMDGPU_GPU AND "amdgpu" IN_LIST LIBOMPTARGET_PLUGINS_TO_BUILD) + list(GET LIBOMPTARGET_AMDGPU_DETECTED_ARCH_LIST 0 LIBOMPTARGET_DEP_AMDGPU_ARCH) + set(OFFLOAD_TEST_TARGET_AMDGPU ON) +endif() + +add_offload_test_device_code(foo.c foo) +add_offload_test_device_code(bar.c bar) + +add_custom_target(LibomptUnitTestsDeviceBins DEPENDS ${BIN_PATHS}) + +set(OFFLOAD_TEST_DEVICE_CODE_PATH ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE) diff --git a/offload/unittests/OffloadAPI/device_code/bar.c b/offload/unittests/OffloadAPI/device_code/bar.c new file mode 100644 index 000000000000..786aa2f5d61e --- /dev/null +++ b/offload/unittests/OffloadAPI/device_code/bar.c @@ -0,0 +1,5 @@ +#include + +__gpu_kernel void foo(int *out) { + out[__gpu_thread_id(0)] = __gpu_thread_id(0) + 1; +} diff --git a/offload/unittests/OffloadAPI/device_code/foo.c b/offload/unittests/OffloadAPI/device_code/foo.c new file mode 100644 index 000000000000..5bc893961d49 --- /dev/null +++ b/offload/unittests/OffloadAPI/device_code/foo.c @@ -0,0 +1,5 @@ +#include + +__gpu_kernel void foo(int *out) { + out[__gpu_thread_id(0)] = __gpu_thread_id(0); +} diff --git a/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp b/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp new file mode 100644 index 000000000000..f320d191ad58 --- /dev/null +++ b/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp @@ -0,0 
+1,30 @@ +//===------- Offload API tests - olGetKernel ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +using olGetKernelTest = OffloadProgramTest; + +TEST_F(olGetKernelTest, Success) { + ol_kernel_handle_t Kernel = nullptr; + ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel)); + ASSERT_NE(Kernel, nullptr); +} + +TEST_F(olGetKernelTest, InvalidNullProgram) { + ol_kernel_handle_t Kernel = nullptr; + ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, + olGetKernel(nullptr, "foo", &Kernel)); +} + +TEST_F(olGetKernelTest, InvalidNullKernelPointer) { + ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, + olGetKernel(Program, "foo", nullptr)); +} diff --git a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp new file mode 100644 index 000000000000..2e51a48b9a7a --- /dev/null +++ b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp @@ -0,0 +1,83 @@ +//===------- Offload API tests - olLaunchKernel --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +struct olLaunchKernelTest : OffloadQueueTest { + void SetUp() override { + RETURN_ON_FATAL_FAILURE(OffloadQueueTest::SetUp()); + ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin)); + ASSERT_GE(DeviceBin->getBufferSize(), 0lu); + ASSERT_SUCCESS(olCreateProgram(Device, DeviceBin->getBufferStart(), + DeviceBin->getBufferSize(), &Program)); + ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel)); + LaunchArgs.Dimensions = 1; + LaunchArgs.GroupSizeX = 64; + LaunchArgs.GroupSizeY = 1; + LaunchArgs.GroupSizeZ = 1; + + LaunchArgs.NumGroupsX = 1; + LaunchArgs.NumGroupsY = 1; + LaunchArgs.NumGroupsZ = 1; + + LaunchArgs.DynSharedMemory = 0; + } + + void TearDown() override { + if (Program) { + olDestroyProgram(Program); + } + RETURN_ON_FATAL_FAILURE(OffloadQueueTest::TearDown()); + } + + std::unique_ptr DeviceBin; + ol_program_handle_t Program = nullptr; + ol_kernel_handle_t Kernel = nullptr; + ol_kernel_launch_size_args_t LaunchArgs{}; +}; + +TEST_F(olLaunchKernelTest, Success) { + void *Mem; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 64, &Mem)); + struct { + void *Mem; + } Args{Mem}; + + ASSERT_SUCCESS(olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args), + &LaunchArgs, nullptr)); + + ASSERT_SUCCESS(olWaitQueue(Queue)); + + int *Data = (int *)Mem; + for (int i = 0; i < 64; i++) { + ASSERT_EQ(Data[i], i); + } + + ASSERT_SUCCESS(olMemFree(Mem)); +} + +TEST_F(olLaunchKernelTest, SuccessSynchronous) { + void *Mem; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 64, &Mem)); + + struct { + void *Mem; + } Args{Mem}; + + ASSERT_SUCCESS(olLaunchKernel(nullptr, Device, Kernel, &Args, sizeof(Args), + &LaunchArgs, nullptr)); + + int *Data = (int *)Mem; + for (int i = 0; i < 64; i++) { + ASSERT_EQ(Data[i], i); + } + + 
ASSERT_SUCCESS(olMemFree(Mem)); +} diff --git a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp new file mode 100644 index 000000000000..580ba022954e --- /dev/null +++ b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp @@ -0,0 +1,45 @@ +//===------- Offload API tests - olMemAlloc -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +using olMemAllocTest = OffloadDeviceTest; + +TEST_F(olMemAllocTest, SuccessAllocManaged) { + void *Alloc = nullptr; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc)); + ASSERT_NE(Alloc, nullptr); + olMemFree(Alloc); +} + +TEST_F(olMemAllocTest, SuccessAllocHost) { + void *Alloc = nullptr; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc)); + ASSERT_NE(Alloc, nullptr); + olMemFree(Alloc); +} + +TEST_F(olMemAllocTest, SuccessAllocDevice) { + void *Alloc = nullptr; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc)); + ASSERT_NE(Alloc, nullptr); + olMemFree(Alloc); +} + +TEST_F(olMemAllocTest, InvalidNullDevice) { + void *Alloc = nullptr; + ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, + olMemAlloc(nullptr, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc)); +} + +TEST_F(olMemAllocTest, InvalidNullOutPtr) { + ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, + olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, nullptr)); +} diff --git a/offload/unittests/OffloadAPI/memory/olMemFree.cpp b/offload/unittests/OffloadAPI/memory/olMemFree.cpp new file mode 100644 index 000000000000..99ad389f27fb --- /dev/null +++ b/offload/unittests/OffloadAPI/memory/olMemFree.cpp @@ -0,0 +1,38 @@ +//===------- Offload API tests - 
olMemFree --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +using olMemFreeTest = OffloadDeviceTest; + +TEST_F(olMemFreeTest, SuccessFreeManaged) { + void *Alloc = nullptr; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc)); + ASSERT_SUCCESS(olMemFree(Alloc)); +} + +TEST_F(olMemFreeTest, SuccessFreeHost) { + void *Alloc = nullptr; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc)); + ASSERT_SUCCESS(olMemFree(Alloc)); +} + +TEST_F(olMemFreeTest, SuccessFreeDevice) { + void *Alloc = nullptr; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc)); + ASSERT_SUCCESS(olMemFree(Alloc)); +} + +TEST_F(olMemFreeTest, InvalidNullPtr) { + void *Alloc = nullptr; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc)); + ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(nullptr)); + ASSERT_SUCCESS(olMemFree(Alloc)); +} diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp new file mode 100644 index 000000000000..b00ded9b53ed --- /dev/null +++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp @@ -0,0 +1,106 @@ +//===------- Offload API tests - olMemcpy --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +using olMemcpyTest = OffloadQueueTest; + +TEST_F(olMemcpyTest, SuccessHtoD) { + constexpr size_t Size = 1024; + void *Alloc; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, Size, &Alloc)); + std::vector Input(Size, 42); + ASSERT_SUCCESS( + olMemcpy(Queue, Alloc, Device, Input.data(), Host, Size, nullptr)); + olWaitQueue(Queue); + olMemFree(Alloc); +} + +TEST_F(olMemcpyTest, SuccessDtoH) { + constexpr size_t Size = 1024; + void *Alloc; + std::vector Input(Size, 42); + std::vector Output(Size, 0); + + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, Size, &Alloc)); + ASSERT_SUCCESS( + olMemcpy(Queue, Alloc, Device, Input.data(), Host, Size, nullptr)); + ASSERT_SUCCESS( + olMemcpy(Queue, Output.data(), Host, Alloc, Device, Size, nullptr)); + ASSERT_SUCCESS(olWaitQueue(Queue)); + for (uint8_t Val : Output) { + ASSERT_EQ(Val, 42); + } + ASSERT_SUCCESS(olMemFree(Alloc)); +} + +TEST_F(olMemcpyTest, SuccessDtoD) { + constexpr size_t Size = 1024; + void *AllocA; + void *AllocB; + std::vector Input(Size, 42); + std::vector Output(Size, 0); + + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, Size, &AllocA)); + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, Size, &AllocB)); + ASSERT_SUCCESS( + olMemcpy(Queue, AllocA, Device, Input.data(), Host, Size, nullptr)); + ASSERT_SUCCESS( + olMemcpy(Queue, AllocB, Device, AllocA, Device, Size, nullptr)); + ASSERT_SUCCESS( + olMemcpy(Queue, Output.data(), Host, AllocB, Device, Size, nullptr)); + ASSERT_SUCCESS(olWaitQueue(Queue)); + for (uint8_t Val : Output) { + ASSERT_EQ(Val, 42); + } + ASSERT_SUCCESS(olMemFree(AllocA)); + ASSERT_SUCCESS(olMemFree(AllocB)); +} + +TEST_F(olMemcpyTest, SuccessHtoHSync) { + constexpr size_t Size = 1024; + std::vector Input(Size, 42); + std::vector Output(Size, 0); + + 
ASSERT_SUCCESS(olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, + Size, nullptr)); + + for (uint8_t Val : Output) { + ASSERT_EQ(Val, 42); + } +} + +TEST_F(olMemcpyTest, SuccessDtoHSync) { + constexpr size_t Size = 1024; + void *Alloc; + std::vector Input(Size, 42); + std::vector Output(Size, 0); + + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, Size, &Alloc)); + ASSERT_SUCCESS( + olMemcpy(nullptr, Alloc, Device, Input.data(), Host, Size, nullptr)); + ASSERT_SUCCESS( + olMemcpy(nullptr, Output.data(), Host, Alloc, Device, Size, nullptr)); + for (uint8_t Val : Output) { + ASSERT_EQ(Val, 42); + } + ASSERT_SUCCESS(olMemFree(Alloc)); +} + +TEST_F(olMemcpyTest, SuccessSizeZero) { + constexpr size_t Size = 1024; + std::vector Input(Size, 42); + std::vector Output(Size, 0); + + // As with std::memcpy, size 0 is allowed. Keep all other arguments valid even + // if they aren't used. + ASSERT_SUCCESS( + olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, 0, nullptr)); +} diff --git a/offload/unittests/OffloadAPI/platform/olGetPlatform.cpp b/offload/unittests/OffloadAPI/platform/olGetPlatform.cpp deleted file mode 100644 index 4a2f9e8ac774..000000000000 --- a/offload/unittests/OffloadAPI/platform/olGetPlatform.cpp +++ /dev/null @@ -1,28 +0,0 @@ -//===------- Offload API tests - olGetPlatform -----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "../common/Fixtures.hpp" -#include -#include - -using olGetPlatformTest = offloadTest; - -TEST_F(olGetPlatformTest, Success) { - uint32_t PlatformCount; - ASSERT_SUCCESS(olGetPlatformCount(&PlatformCount)); - std::vector Platforms(PlatformCount); - ASSERT_SUCCESS(olGetPlatform(PlatformCount, Platforms.data())); -} - -TEST_F(olGetPlatformTest, InvalidNumEntries) { - uint32_t PlatformCount; - ASSERT_SUCCESS(olGetPlatformCount(&PlatformCount)); - std::vector Platforms(PlatformCount); - ASSERT_ERROR(OL_ERRC_INVALID_SIZE, - olGetPlatform(PlatformCount + 1, Platforms.data())); -} diff --git a/offload/unittests/OffloadAPI/platform/olGetPlatformInfo.cpp b/offload/unittests/OffloadAPI/platform/olGetPlatformInfo.cpp index c646bdc50b7d..bd6ad3f84e77 100644 --- a/offload/unittests/OffloadAPI/platform/olGetPlatformInfo.cpp +++ b/offload/unittests/OffloadAPI/platform/olGetPlatformInfo.cpp @@ -12,7 +12,7 @@ #include "olPlatformInfo.hpp" struct olGetPlatformInfoTest - : offloadPlatformTest, + : OffloadPlatformTest, ::testing::WithParamInterface {}; INSTANTIATE_TEST_SUITE_P( diff --git a/offload/unittests/OffloadAPI/platform/olGetPlatformInfoSize.cpp b/offload/unittests/OffloadAPI/platform/olGetPlatformInfoSize.cpp index 7c9274082e8e..5f6067e2e259 100644 --- a/offload/unittests/OffloadAPI/platform/olGetPlatformInfoSize.cpp +++ b/offload/unittests/OffloadAPI/platform/olGetPlatformInfoSize.cpp @@ -12,7 +12,7 @@ #include "olPlatformInfo.hpp" struct olGetPlatformInfoSizeTest - : offloadPlatformTest, + : OffloadPlatformTest, ::testing::WithParamInterface {}; INSTANTIATE_TEST_SUITE_P( diff --git a/offload/unittests/OffloadAPI/platform/olPlatformInfo.hpp b/offload/unittests/OffloadAPI/platform/olPlatformInfo.hpp index d49cdb90d321..f61bca0cf52f 100644 --- a/offload/unittests/OffloadAPI/platform/olPlatformInfo.hpp +++ 
b/offload/unittests/OffloadAPI/platform/olPlatformInfo.hpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #pragma once +#include #include // TODO: We could autogenerate these diff --git a/offload/unittests/OffloadAPI/program/olCreateProgram.cpp b/offload/unittests/OffloadAPI/program/olCreateProgram.cpp new file mode 100644 index 000000000000..c586c0459620 --- /dev/null +++ b/offload/unittests/OffloadAPI/program/olCreateProgram.cpp @@ -0,0 +1,27 @@ +//===------- Offload API tests - olCreateProgram --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +using olCreateProgramTest = OffloadDeviceTest; + +TEST_F(olCreateProgramTest, Success) { + + std::unique_ptr DeviceBin; + ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin)); + ASSERT_GE(DeviceBin->getBufferSize(), 0lu); + + ol_program_handle_t Program; + ASSERT_SUCCESS(olCreateProgram(Device, DeviceBin->getBufferStart(), + DeviceBin->getBufferSize(), &Program)); + ASSERT_NE(Program, nullptr); + + ASSERT_SUCCESS(olDestroyProgram(Program)); +} diff --git a/offload/unittests/OffloadAPI/platform/olGetPlatformCount.cpp b/offload/unittests/OffloadAPI/program/olDestroyProgram.cpp similarity index 50% rename from offload/unittests/OffloadAPI/platform/olGetPlatformCount.cpp rename to offload/unittests/OffloadAPI/program/olDestroyProgram.cpp index 15b4b6abcd70..ea21dadb59c4 100644 --- a/offload/unittests/OffloadAPI/platform/olGetPlatformCount.cpp +++ b/offload/unittests/OffloadAPI/program/olDestroyProgram.cpp @@ -1,4 +1,4 @@ -//===------- Offload API tests - olGetPlatformCount ------------------===// +//===------- Offload API tests - 
olDestroyProgram -------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -10,13 +10,13 @@ #include #include -using olGetPlatformCountTest = offloadTest; +using olDestroyProgramTest = OffloadProgramTest; -TEST_F(olGetPlatformCountTest, Success) { - uint32_t PlatformCount; - ASSERT_SUCCESS(olGetPlatformCount(&PlatformCount)); +TEST_F(olDestroyProgramTest, Success) { + ASSERT_SUCCESS(olDestroyProgram(Program)); + Program = nullptr; } -TEST_F(olGetPlatformCountTest, InvalidNullPointer) { - ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olGetPlatformCount(nullptr)); +TEST_F(olDestroyProgramTest, InvalidNullHandle) { + ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olDestroyProgram(nullptr)); } diff --git a/offload/unittests/OffloadAPI/queue/olCreateQueue.cpp b/offload/unittests/OffloadAPI/queue/olCreateQueue.cpp new file mode 100644 index 000000000000..0534debed055 --- /dev/null +++ b/offload/unittests/OffloadAPI/queue/olCreateQueue.cpp @@ -0,0 +1,28 @@ +//===------- Offload API tests - olCreateQueue ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +using olCreateQueueTest = OffloadDeviceTest; + +TEST_F(olCreateQueueTest, Success) { + ol_queue_handle_t Queue = nullptr; + ASSERT_SUCCESS(olCreateQueue(Device, &Queue)); + ASSERT_NE(Queue, nullptr); +} + +TEST_F(olCreateQueueTest, InvalidNullHandleDevice) { + ol_queue_handle_t Queue = nullptr; + ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olCreateQueue(nullptr, &Queue)); +} + +TEST_F(olCreateQueueTest, InvalidNullPointerQueue) { + ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olCreateQueue(Device, nullptr)); +} diff --git a/offload/unittests/OffloadAPI/queue/olDestroyQueue.cpp b/offload/unittests/OffloadAPI/queue/olDestroyQueue.cpp new file mode 100644 index 000000000000..b54694e0c798 --- /dev/null +++ b/offload/unittests/OffloadAPI/queue/olDestroyQueue.cpp @@ -0,0 +1,22 @@ +//===------- Offload API tests - olDestroyQueue ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +using olDestroyQueueTest = OffloadQueueTest; + +TEST_F(olDestroyQueueTest, Success) { + ASSERT_SUCCESS(olDestroyQueue(Queue)); + Queue = nullptr; +} + +TEST_F(olDestroyQueueTest, InvalidNullHandle) { + ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olDestroyQueue(nullptr)); +} diff --git a/offload/unittests/OffloadAPI/queue/olWaitQueue.cpp b/offload/unittests/OffloadAPI/queue/olWaitQueue.cpp new file mode 100644 index 000000000000..07ef774583ae --- /dev/null +++ b/offload/unittests/OffloadAPI/queue/olWaitQueue.cpp @@ -0,0 +1,17 @@ +//===------- Offload API tests - olWaitQueue ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +using olWaitQueueTest = OffloadQueueTest; + +TEST_F(olWaitQueueTest, SuccessEmptyQueue) { + ASSERT_SUCCESS(olWaitQueue(Queue)); +}