diff --git a/offload/include/Shared/RefCnt.h b/offload/include/Shared/RefCnt.h index 7c615ba167a3..5031a6ff9024 100644 --- a/offload/include/Shared/RefCnt.h +++ b/offload/include/Shared/RefCnt.h @@ -31,16 +31,16 @@ struct RefCountTy { ~RefCountTy() { assert(Refs == 0 && "Destroying with non-zero refcount"); } - /// Increase the reference count atomically. - void increase() { Refs.fetch_add(1, MemoryOrder); } + /// Increase the reference count atomically by \p Amount. + void increase(Ty Amount = 1) { Refs.fetch_add(Amount, MemoryOrder); } - /// Decrease the reference count and return whether it became zero. Decreasing - /// the counter in more units than it was previously increased results in - /// undefined behavior. - bool decrease() { - Ty Prev = Refs.fetch_sub(1, MemoryOrder); - assert(Prev > 0 && "Invalid refcount"); - return (Prev == 1); + /// Decrease the reference count by \p Amount and return whether it became + /// zero. Decreasing the counter by more than it was previously increased + /// results in undefined behavior. 
+ bool decrease(Ty Amount = 1) { + Ty Prev = Refs.fetch_sub(Amount, MemoryOrder); + assert(Prev >= Amount && "Invalid refcount"); + return (Prev == Amount); } Ty get() const { return Refs.load(MemoryOrder); } diff --git a/offload/liboffload/API/Event.td b/offload/liboffload/API/Event.td index 075bf5bafaa6..be77500562a1 100644 --- a/offload/liboffload/API/Event.td +++ b/offload/liboffload/API/Event.td @@ -13,7 +13,8 @@ def olCreateEvent : Function { let desc = "Enqueue an event to `Queue` and return it."; let details = [ - "This event can be used with `olSyncEvent` and `olWaitEvents` and will be complete once all enqueued work prior to the `olCreateEvent` call is complete.", + "This event can be used with `olSyncEvent`, `olWaitEvents`, and `olGetEventElapsedTime`.", + "It will be complete once all enqueued work prior to the `olCreateEvent` call is complete.", ]; let params = [ Param<"ol_queue_handle_t", "Queue", "queue to create the event for", PARAM_IN>, @@ -40,6 +41,20 @@ def olSyncEvent : Function { let returns = []; } +def olGetEventElapsedTime : Function { + let desc = "Get the elapsed time in milliseconds between two events."; + let details = [ + "The elapsed time is returned in milliseconds. Both events must have completed, for example after `olSyncEvent`, before their elapsed time can be queried.", + "The queues associated with `StartEvent` and `EndEvent` must belong to the same device." 
+ ]; + let params = [ + Param<"ol_event_handle_t", "StartEvent", "handle of the start event", PARAM_IN>, + Param<"ol_event_handle_t", "EndEvent", "handle of the end event", PARAM_IN>, + Param<"float*", "ElapsedTime", "output pointer for the elapsed time in milliseconds", PARAM_OUT> + ]; + let returns = []; +} + def ol_event_info_t : Enum { let desc = "Supported event info."; let is_typed = 1; diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index dd3ec0f61b4d..77933e6291f4 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -175,8 +175,8 @@ struct ol_event_impl_t { ol_queue_handle_t Queue) : EventInfo(EventInfo), Device(Device), QueueId(Queue->Id), Queue(Queue) { } - // EventInfo may be null, in which case the event should be considered always - // complete + // Opaque backend-specific event state. This is expected to be non-null for + // backends that materialize real events. void *EventInfo; ol_device_handle_t Device; size_t QueueId; @@ -794,7 +794,8 @@ Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events, return Plugin::error(ErrorCode::INVALID_NULL_HANDLE, "olWaitEvents asked to wait on a NULL event"); - // Do nothing if the event is for this queue or the event is always complete + // Do nothing if the event is for this queue or the backend does not + // materialize event state for it. if (Event->QueueId == Queue->Id || !Event->EventInfo) continue; @@ -839,7 +840,8 @@ Error olGetQueueInfoSize_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName, } Error olSyncEvent_impl(ol_event_handle_t Event) { - // No event info means that this event was complete on creation + // Some backends do not materialize backend event state. Treat such events as + // trivially complete. 
if (!Event->EventInfo) return Plugin::success(); @@ -849,6 +851,23 @@ Error olSyncEvent_impl(ol_event_handle_t Event) { return Error::success(); } +Error olGetEventElapsedTime_impl(ol_event_handle_t StartEvent, + ol_event_handle_t EndEvent, + float *ElapsedTime) { + if (StartEvent->Device != EndEvent->Device) + return createOffloadError( + ErrorCode::INVALID_DEVICE, + "StartEvent and EndEvent must belong to the same device"); + + auto ElapsedTimeOrErr = StartEvent->Device->Device->getEventElapsedTime( + StartEvent->EventInfo, EndEvent->EventInfo); + if (!ElapsedTimeOrErr) + return ElapsedTimeOrErr.takeError(); + + *ElapsedTime = *ElapsedTimeOrErr; + return Error::success(); +} + Error olDestroyEvent_impl(ol_event_handle_t Event) { if (Event->EventInfo) if (auto Res = Event->Device->Device->destroyEvent(Event->EventInfo)) @@ -867,7 +886,8 @@ Error olGetEventInfoImplDetail(ol_event_handle_t Event, case OL_EVENT_INFO_QUEUE: return Info.write(Queue); case OL_EVENT_INFO_IS_COMPLETE: { - // No event info means that this event was complete on creation + // Some backends do not materialize backend event state. Treat such events + // as trivially complete. 
if (!Event->EventInfo) return Info.write(true); @@ -898,24 +918,24 @@ Error olGetEventInfoSize_impl(ol_event_handle_t Event, ol_event_info_t PropName, } Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) { - auto Pending = Queue->Device->Device->hasPendingWork(Queue->AsyncInfo); - if (auto Err = Pending.takeError()) + auto Event = std::make_unique<ol_event_impl_t>(nullptr, Queue->Device, Queue); + + if (auto Err = Queue->Device->Device->createEvent(&Event->EventInfo)) return Err; - *EventOut = new ol_event_impl_t(nullptr, Queue->Device, Queue); - if (!*Pending) - // Queue is empty, don't record an event and consider the event always - // complete - return Plugin::success(); + if (auto Err = Queue->Device->Device->recordEvent(Event->EventInfo, + Queue->AsyncInfo)) { + if (Event->EventInfo) { + if (auto DestroyErr = + Queue->Device->Device->destroyEvent(Event->EventInfo)) + return joinErrors(std::move(Err), std::move(DestroyErr)); + } - if (auto Res = Queue->Device->Device->createEvent(&(*EventOut)->EventInfo)) - return Res; + return Err; + } - if (auto Res = Queue->Device->Device->recordEvent((*EventOut)->EventInfo, - Queue->AsyncInfo)) - return Res; - - return Plugin::success(); + *EventOut = Event.release(); + return Error::success(); } Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr, diff --git a/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa.cpp b/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa.cpp index 37d12861eb38..279a296dd161 100644 --- a/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa.cpp +++ b/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa.cpp @@ -70,6 +70,8 @@ DLWRAP(hsa_amd_register_system_event_handler, 2) DLWRAP(hsa_amd_signal_create, 5) DLWRAP(hsa_amd_signal_async_handler, 5) DLWRAP(hsa_amd_pointer_info, 5) +DLWRAP(hsa_amd_profiling_get_dispatch_time, 3) +DLWRAP(hsa_amd_profiling_set_profiler_enabled, 2) DLWRAP(hsa_code_object_reader_create_from_memory, 3) DLWRAP(hsa_code_object_reader_destroy, 1) 
DLWRAP(hsa_executable_load_agent_code_object, 5) diff --git a/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa.h b/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa.h index ad135f72fff1..f6e3337ddb3f 100644 --- a/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa.h +++ b/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa.h @@ -99,6 +99,8 @@ typedef enum { typedef enum { HSA_SYSTEM_INFO_VERSION_MAJOR = 0, HSA_SYSTEM_INFO_VERSION_MINOR = 1, + HSA_SYSTEM_INFO_TIMESTAMP = 2, + HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY = 3, } hsa_system_info_t; typedef enum { diff --git a/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa_ext_amd.h b/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa_ext_amd.h index ddfa65c76cf2..7ff77f8e2a2f 100644 --- a/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa_ext_amd.h +++ b/offload/plugins-nextgen/amdgpu/dynamic_hsa/hsa_ext_amd.h @@ -169,6 +169,18 @@ hsa_status_t hsa_amd_pointer_info(const void* ptr, uint32_t* num_agents_accessible, hsa_agent_t** accessible); +typedef struct hsa_amd_profiling_dispatch_time_s { + uint64_t start; + uint64_t end; +} hsa_amd_profiling_dispatch_time_t; + +hsa_status_t +hsa_amd_profiling_get_dispatch_time(hsa_agent_t agent, hsa_signal_t signal, + hsa_amd_profiling_dispatch_time_t *time); + +hsa_status_t hsa_amd_profiling_set_profiler_enabled(hsa_queue_t *queue, + int enable); + #ifdef __cplusplus } #endif diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp index 83bb4262e669..af0d09ebeb57 100644 --- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp @@ -690,11 +690,14 @@ struct AMDGPUSignalTy { /// plugin thread or the HSA runtime. void reset() { hsa_signal_store_screlease(HSASignal, 1); } - /// Increase the number of concurrent uses. - void increaseUseCount() { UseCount.increase(); } + /// Increase the number of concurrent uses by \p Amount. 
+ void increaseUseCount(uint32_t Amount = 1) { UseCount.increase(Amount); } - /// Decrease the number of concurrent uses and return whether was the last. - bool decreaseUseCount() { return UseCount.decrease(); } + /// Decrease the number of concurrent uses by \p Amount and return whether it + /// became zero. + bool decreaseUseCount(uint32_t Amount = 1) { + return UseCount.decrease(Amount); + } hsa_signal_t get() const { return HSASignal; } @@ -704,7 +707,7 @@ private: /// Reference counter for tracking the concurrent use count. This is mainly /// used for knowing how many streams are using the signal. - RefCountTy<> UseCount; + RefCountTy UseCount; }; /// Classes for holding AMDGPU signals and managing signals. @@ -720,10 +723,22 @@ struct AMDGPUQueueTy { Error init(GenericDeviceTy &Device, hsa_agent_t Agent, int32_t QueueSize) { if (Queue) return Plugin::success(); + hsa_status_t Status = hsa_queue_create(Agent, QueueSize, HSA_QUEUE_TYPE_MULTI, callbackError, &Device, UINT32_MAX, UINT32_MAX, &Queue); - return Plugin::check(Status, "error in hsa_queue_create: %s"); + if (auto Err = Plugin::check(Status, "error in hsa_queue_create: %s")) + return Err; + + // Enable queue profiling from creation time onward, as HIP/ROCclr does. + // Elapsed-time queries rely on queue-level hardware profiling support to + // retrieve packet timing. + Status = hsa_amd_profiling_set_profiler_enabled(Queue, 1); + if (auto Err = Plugin::check( + Status, "error in hsa_amd_profiling_set_profiler_enabled: %s")) + return Err; + + return Plugin::success(); } /// Deinitialize the queue and destroy its resources. @@ -1144,6 +1159,18 @@ private: return {Curr, InputSignal}; } + /// Roll back the last consumed slot after a submission failure so the stream + /// does not retain a slot for an operation that was never enqueued. 
+ void rollbackConsumedSlot(uint32_t Slot) { + assert(NextSlot > 0 && "Cannot roll back an empty stream"); + assert(Slot + 1 == NextSlot && "Can only roll back the last consumed slot"); + + Slots[Slot].Signal = nullptr; + Slots[Slot].Callbacks.clear(); + Slots[Slot].ActionArgs.clear(); + --NextSlot; + } + /// Complete all pending post actions and reset the stream after synchronizing /// or positively querying the stream. Error complete() { @@ -1643,8 +1670,9 @@ public: const AMDGPUQueueTy *getQueue() const { return Queue; } - /// Record the state of the stream on an event. - Error recordEvent(AMDGPUEventTy &Event) const; + /// Record an event by enqueuing a barrier marker packet on the stream. + Error recordEvent(AMDGPUEventTy &Event, + AMDGPUSignalTy *ReusedSignal = nullptr); /// Make the stream wait on an event. Error waitEvent(const AMDGPUEventTy &Event); @@ -1652,25 +1680,46 @@ public: friend struct AMDGPUStreamManagerTy; }; -/// Class representing an event on AMDGPU. The event basically stores some -/// information regarding the state of the recorded stream. +/// Class representing an event on AMDGPU. The event stores the recorded stream +/// point and retained timing state. struct AMDGPUEventTy { /// Create an empty event. AMDGPUEventTy(AMDGPUDeviceTy &Device) - : RecordedStream(nullptr), RecordedSlot(-1), RecordedSyncCycle(-1) {} + : Device(Device), RecordedStream(nullptr), RecordedSlot(-1), + RecordedSyncCycle(-1), TimingSignal(nullptr) {} /// Initialize and deinitialize. - Error init() { return Plugin::success(); } - Error deinit() { return Plugin::success(); } + Error init() { return resetState(); } + Error deinit() { return resetState(); } - /// Record the state of a stream on the event. + /// Clear the current recording and retained timing state, optionally + /// returning a reusable timing signal. 
+ Error resetState(AMDGPUSignalTy **ReusableSignalPtr = nullptr) { + RecordedStream = nullptr; + RecordedSlot = -1; + RecordedSyncCycle = -1; + return releaseTimingSignal(ReusableSignalPtr); + } + + /// Record the current stream point on the event. Error record(AMDGPUStreamTy &Stream) { std::lock_guard Lock(Mutex); - // Ignore the last recorded stream. + // Discard the previous recording and retained timing state, reusing the + // retained timing signal if it becomes available. + AMDGPUSignalTy *Signal = nullptr; + if (auto Err = resetState(&Signal)) + return Err; + RecordedStream = &Stream; - return Stream.recordEvent(*this); + if (auto Err = Stream.recordEvent(*this, Signal)) { + if (auto ResetErr = resetState()) + return joinErrors(std::move(Err), std::move(ResetErr)); + return Err; + } + + return Plugin::success(); } /// Make a stream wait on the current event. @@ -1708,38 +1757,87 @@ struct AMDGPUEventTy { return RecordedStream->synchronizeOn(*this); } + /// Return the elapsed time in milliseconds between this event and EndEvent. + Expected<float> getElapsedTime(AMDGPUEventTy &EndEvent); + protected: + /// Release the retained timing signal, if any, either back to the signal + /// manager or through \p ReusableSignalPtr when provided. + Error releaseTimingSignal(AMDGPUSignalTy **ReusableSignalPtr = nullptr); + + /// The device that owns this event. + AMDGPUDeviceTy &Device; + /// The stream registered in this event. AMDGPUStreamTy *RecordedStream; - /// The recordered operation on the recorded stream. + /// The recorded operation on the recorded stream. int64_t RecordedSlot; /// The sync cycle when the stream was recorded. Used to detect stale events. int64_t RecordedSyncCycle; + /// The signal of the recorded timing barrier. + AMDGPUSignalTy *TimingSignal; + /// Mutex to safely access event fields. 
mutable std::mutex Mutex; friend struct AMDGPUStreamTy; }; -Error AMDGPUStreamTy::recordEvent(AMDGPUEventTy &Event) const { - std::lock_guard Lock(Mutex); +Error AMDGPUStreamTy::recordEvent(AMDGPUEventTy &Event, + AMDGPUSignalTy *ReusedSignal) { + if (Queue == nullptr) { + // The caller transferred ownership of ReusedSignal; hand it back to the + // signal manager so this early exit does not leak it. + auto Err = Plugin::error(ErrorCode::INVALID_NULL_POINTER, + "target queue was nullptr"); + if (ReusedSignal) { + if (auto ReturnErr = SignalManager.returnResource(ReusedSignal)) + return joinErrors(std::move(Err), std::move(ReturnErr)); + } + return Err; + } - if (size() > 0) { - // Record the synchronize identifier (to detect stale recordings) and - // the last valid stream's operation. - Event.RecordedSyncCycle = SyncCycle; - Event.RecordedSlot = last(); + // One use for the stream slot and one for the event timing signal. + const uint32_t OutputSignalUses = 2; - assert(Event.RecordedSyncCycle >= 0 && "Invalid recorded sync cycle"); - assert(Event.RecordedSlot >= 0 && "Invalid recorded slot"); - } else { - // The stream is empty, everything already completed, record nothing. - Event.RecordedSyncCycle = -1; - Event.RecordedSlot = -1; + // Reuse the provided signal or retrieve one for the operation's output. + AMDGPUSignalTy *OutputSignal = ReusedSignal; + if (!OutputSignal) { + if (auto Err = SignalManager.getResource(OutputSignal)) + return Err; } + + OutputSignal->reset(); + OutputSignal->increaseUseCount(OutputSignalUses); + + std::lock_guard StreamLock(Mutex); + + // Consume stream slot and compute dependencies. + auto [Curr, InputSignal] = consume(OutputSignal); + + // Materialize the event as a real marker on the queue. Elapsed-time queries + // need a packet-backed completion signal to retrieve dispatch timing. 
+ if (auto Err = Queue->pushBarrier(OutputSignal, InputSignal, nullptr)) { + rollbackConsumedSlot(Curr); + + if (OutputSignal->decreaseUseCount(OutputSignalUses)) { + if (auto ReturnErr = SignalManager.returnResource(OutputSignal)) + return joinErrors(std::move(Err), std::move(ReturnErr)); + } + + return Err; + } + + Event.RecordedSlot = Curr; + Event.RecordedSyncCycle = SyncCycle; + Event.TimingSignal = OutputSignal; + + assert(Event.RecordedSyncCycle >= 0 && "Invalid recorded sync cycle"); + assert(Event.RecordedSlot >= 0 && "Invalid recorded slot"); + return Plugin::success(); } @@ -2124,6 +2214,12 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { ClockFrequency) != HSA_STATUS_SUCCESS) ClockFrequency = 0; + // Retrieve the HSA system timestamp frequency for this runtime. A zero + // value means the frequency is unavailable. + if (hsa_system_get_info(HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY, + &SystemTimestampFrequency) != HSA_STATUS_SUCCESS) + SystemTimestampFrequency = 0; + // Load the grid values depending on the wavefront. if (WavefrontSize == 32) GridValues = getAMDGPUGridValues<32>(); @@ -2333,6 +2429,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { /// Returns the clock frequency for the given AMDGPU device. uint64_t getClockFrequency() const override { return ClockFrequency; } + /// Returns the HSA system timestamp frequency. Zero means unavailable. + uint64_t getSystemTimestampFrequency() const { + return SystemTimestampFrequency; + } + /// Allocate and construct an AMDGPU kernel. Expected constructKernel(const char *Name) override { // Allocate and construct the AMDGPU kernel. @@ -2813,12 +2914,19 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { /// Create an event. 
Error createEventImpl(void **EventPtrStorage) override { AMDGPUEventTy **Event = reinterpret_cast<AMDGPUEventTy **>(EventPtrStorage); - return AMDGPUEventManager.getResource(*Event); + if (auto Err = AMDGPUEventManager.getResource(*Event)) + return Err; + return (*Event)->resetState(); } /// Destroy a previously created event. Error destroyEventImpl(void *EventPtr) override { AMDGPUEventTy *Event = reinterpret_cast<AMDGPUEventTy *>(EventPtr); + assert(Event && "Invalid event"); + + if (auto Err = Event->resetState()) + return Err; + return AMDGPUEventManager.returnResource(Event); } @@ -2871,6 +2979,19 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { return Event->sync(); } + /// Get the elapsed time in milliseconds between two events. + Expected<float> getEventElapsedTimeImpl(void *StartEventPtr, + void *EndEventPtr) override { + AMDGPUEventTy *StartEvent = + reinterpret_cast<AMDGPUEventTy *>(StartEventPtr); + AMDGPUEventTy *EndEvent = reinterpret_cast<AMDGPUEventTy *>(EndEventPtr); + + if (!StartEvent || !EndEvent) + return Plugin::error(ErrorCode::INVALID_ARGUMENT, "invalid event handle"); + + return StartEvent->getElapsedTime(*EndEvent); + } + /// Print information about the device. Expected obtainInfoImpl() override { char TmpChar[1000]; @@ -3347,6 +3468,10 @@ private: /// The frequency of the steady clock inside the device. uint64_t ClockFrequency; + /// The HSA system timestamp frequency reported by the runtime. Zero means + /// unavailable. + uint64_t SystemTimestampFrequency = 0; + /// The total number of concurrent work items that can be running on the GPU. 
uint64_t HardwareParallelism; @@ -3453,6 +3578,83 @@ AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device) StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()), UseMultipleSdmaEngines(Device.useMultipleSdmaEngines()) {} +Error AMDGPUEventTy::releaseTimingSignal(AMDGPUSignalTy **ReusableSignalPtr) { + AMDGPUSignalTy *Signal = TimingSignal; + TimingSignal = nullptr; + + if (!Signal) + return Plugin::success(); + + if (!Signal->decreaseUseCount()) + return Plugin::success(); + + if (ReusableSignalPtr) { + *ReusableSignalPtr = Signal; + return Plugin::success(); + } + + return Device.getSignalManager().returnResource(Signal); +} + +Expected<float> AMDGPUEventTy::getElapsedTime(AMDGPUEventTy &EndEvent) { + if (this == &EndEvent) { + std::lock_guard Lock(Mutex); + + if (!TimingSignal) + return Plugin::error(ErrorCode::INVALID_ARGUMENT, + "event does not have a recorded timing signal"); + + if (TimingSignal->load()) + return Plugin::error(ErrorCode::UNKNOWN, "event timing is not ready"); + + return 0.0f; + } + + const uint64_t TicksPerSecond = Device.getSystemTimestampFrequency(); + if (TicksPerSecond == 0) + return Plugin::error(ErrorCode::UNSUPPORTED, + "HSA system timestamp frequency is unavailable"); + + std::scoped_lock Lock(Mutex, EndEvent.Mutex); + + if (&Device != &EndEvent.Device) + return Plugin::error(ErrorCode::INVALID_ARGUMENT, + "events belong to different devices"); + + if (!TimingSignal || !EndEvent.TimingSignal) + return Plugin::error( + ErrorCode::INVALID_ARGUMENT, + "one or both events do not have a recorded timing signal"); + + if (TimingSignal->load() || EndEvent.TimingSignal->load()) + return Plugin::error( + ErrorCode::UNKNOWN, + "timing information is not ready for one or both events"); + + hsa_amd_profiling_dispatch_time_t StartTime = {}; + hsa_amd_profiling_dispatch_time_t StopTime = {}; + + hsa_status_t Status = hsa_amd_profiling_get_dispatch_time( + Device.getAgent(), TimingSignal->get(), &StartTime); + if (auto Err = Plugin::check( + 
Status, "error in hsa_amd_profiling_get_dispatch_time: %s")) + return std::move(Err); + + Status = hsa_amd_profiling_get_dispatch_time( + EndEvent.Device.getAgent(), EndEvent.TimingSignal->get(), &StopTime); + if (auto Err = Plugin::check( + Status, "error in hsa_amd_profiling_get_dispatch_time: %s")) + return std::move(Err); + + const int64_t DeltaTicks = + static_cast<int64_t>(StopTime.end) - static_cast<int64_t>(StartTime.end); + constexpr double MillisecondsPerSecond = 1000.0; + + return static_cast<float>(static_cast<double>(DeltaTicks) * + MillisecondsPerSecond / + static_cast<double>(TicksPerSecond)); +} + /// Class implementing the AMDGPU-specific functionalities of the global /// handler. struct AMDGPUGlobalHandlerTy final : public GenericGlobalHandlerTy { diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h index 20f8279dfee2..7990b09d59c6 100644 --- a/offload/plugins-nextgen/common/include/PluginInterface.h +++ b/offload/plugins-nextgen/common/include/PluginInterface.h @@ -994,6 +994,11 @@ struct GenericDeviceTy : public DeviceAllocatorTy { Error syncEvent(void *EventPtr); virtual Error syncEventImpl(void *EventPtr) = 0; + /// Get the elapsed time in milliseconds between two events. + Expected<float> getEventElapsedTime(void *StartEventPtr, void *EndEventPtr); + virtual Expected<float> getEventElapsedTimeImpl(void *StartEventPtr, + void *EndEventPtr) = 0; + /// Obtain information about the device. Expected obtainInfo(); virtual Expected obtainInfoImpl() = 0; @@ -1552,6 +1557,10 @@ public: /// Synchronize execution until an event is done. int32_t sync_event(int32_t DeviceId, void *EventPtr); + /// Get the elapsed time in milliseconds between two events. + int32_t get_event_elapsed_time(int32_t DeviceId, void *StartEventPtr, + void *EndEventPtr, float *ElapsedTime); + /// Remove the event from the plugin. 
int32_t destroy_event(int32_t DeviceId, void *EventPtr); diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index 8a32e4177d3d..3420678cac98 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -1568,6 +1568,11 @@ Error GenericDeviceTy::syncEvent(void *EventPtr) { return syncEventImpl(EventPtr); } +Expected<float> GenericDeviceTy::getEventElapsedTime(void *StartEventPtr, + void *EndEventPtr) { + return getEventElapsedTimeImpl(StartEventPtr, EndEventPtr); +} + bool GenericDeviceTy::useAutoZeroCopy() { return useAutoZeroCopyImpl(); } Expected GenericDeviceTy::isAccessiblePtr(const void *Ptr, size_t Size) { @@ -2087,6 +2092,23 @@ int32_t GenericPluginTy::sync_event(int32_t DeviceId, void *EventPtr) { return OFFLOAD_SUCCESS; } +int32_t GenericPluginTy::get_event_elapsed_time(int32_t DeviceId, + void *StartEventPtr, + void *EndEventPtr, + float *ElapsedTime) { + auto ElapsedTimeOrErr = + getDevice(DeviceId).getEventElapsedTime(StartEventPtr, EndEventPtr); + if (!ElapsedTimeOrErr) { + REPORT() << "Failure to get elapsed time between events " << StartEventPtr + << " and " << EndEventPtr << ": " + << toString(ElapsedTimeOrErr.takeError()); + return OFFLOAD_FAIL; + } + + *ElapsedTime = *ElapsedTimeOrErr; + return OFFLOAD_SUCCESS; +} + int32_t GenericPluginTy::destroy_event(int32_t DeviceId, void *EventPtr) { auto Err = getDevice(DeviceId).destroyEvent(EventPtr); if (Err) { diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp index 80e3e418ae3f..8fc8d0e43fab 100644 --- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp +++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp @@ -97,6 +97,7 @@ DLWRAP(cuEventRecord, 2) DLWRAP(cuEventQuery, 1) DLWRAP(cuStreamWaitEvent, 3) DLWRAP(cuEventSynchronize, 1) +DLWRAP(cuEventElapsedTime, 3) DLWRAP(cuEventDestroy, 1) 
DLWRAP_FINALIZE() diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h index fa4f4634ecec..dd47fb98dc03 100644 --- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h +++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h @@ -371,6 +371,7 @@ CUresult cuEventRecord(CUevent, CUstream); CUresult cuEventQuery(CUevent); CUresult cuStreamWaitEvent(CUstream, CUevent, unsigned int); CUresult cuEventSynchronize(CUevent); +CUresult cuEventElapsedTime(float *, CUevent, CUevent); CUresult cuEventDestroy(CUevent); CUresult cuMemUnmap(CUdeviceptr ptr, size_t size); diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp index 4de754265ea7..7a47f2ce7e5a 100644 --- a/offload/plugins-nextgen/cuda/src/rtl.cpp +++ b/offload/plugins-nextgen/cuda/src/rtl.cpp @@ -1086,6 +1086,20 @@ struct CUDADeviceTy : public GenericDeviceTy { return Plugin::check(Res, "error in cuEventSynchronize: %s"); } + /// Get the elapsed time in milliseconds between two events. + Expected<float> getEventElapsedTimeImpl(void *StartEventPtr, + void *EndEventPtr) override { + CUevent StartEvent = reinterpret_cast<CUevent>(StartEventPtr); + CUevent EndEvent = reinterpret_cast<CUevent>(EndEventPtr); + + float ElapsedTime = 0.0f; + CUresult Res = cuEventElapsedTime(&ElapsedTime, StartEvent, EndEvent); + if (auto Err = Plugin::check(Res, "error in cuEventElapsedTime: %s")) + return std::move(Err); + + return ElapsedTime; + } + /// Print information about the device. 
Expected obtainInfoImpl() override { char TmpChar[1000]; diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp index 077dd14b959e..bef49faf4538 100644 --- a/offload/plugins-nextgen/host/src/rtl.cpp +++ b/offload/plugins-nextgen/host/src/rtl.cpp @@ -360,6 +360,10 @@ struct GenELF64DeviceTy : public GenericDeviceTy { return true; } Error syncEventImpl(void *EventPtr) override { return Plugin::success(); } + Expected<float> getEventElapsedTimeImpl(void *StartEventPtr, + void *EndEventPtr) override { + return 0.0f; + } /// Print information about the device. Expected obtainInfoImpl() override { diff --git a/offload/plugins-nextgen/level_zero/include/L0Device.h b/offload/plugins-nextgen/level_zero/include/L0Device.h index a06fabd2d407..8c83d14f8d1b 100644 --- a/offload/plugins-nextgen/level_zero/include/L0Device.h +++ b/offload/plugins-nextgen/level_zero/include/L0Device.h @@ -630,6 +630,12 @@ public: __func__); } + Expected<float> getEventElapsedTimeImpl(void *StartEventPtr, + void *EndEventPtr) override { + return Plugin::error(error::ErrorCode::UNKNOWN, "%s not implemented yet\n", + __func__); + } + Expected obtainInfoImpl() override; uint64_t getClockFrequency() const override { return getClockRate(); } uint64_t getHardwareParallelism() const override { return getTotalThreads(); } diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt index 031dbea660fb..39863391f27d 100644 --- a/offload/unittests/OffloadAPI/CMakeLists.txt +++ b/offload/unittests/OffloadAPI/CMakeLists.txt @@ -13,6 +13,7 @@ add_offload_unittest("event" event/olCreateEvent.cpp event/olDestroyEvent.cpp event/olSyncEvent.cpp + event/olGetEventElapsedTime.cpp event/olGetEventInfo.cpp event/olGetEventInfoSize.cpp) diff --git a/offload/unittests/OffloadAPI/event/olGetEventElapsedTime.cpp b/offload/unittests/OffloadAPI/event/olGetEventElapsedTime.cpp new file mode 100644 index 000000000000..aca2dccff72f --- /dev/null +++ 
b/offload/unittests/OffloadAPI/event/olGetEventElapsedTime.cpp @@ -0,0 +1,146 @@ +//===------- Offload API tests - olGetEventElapsedTime --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include "llvm/Support/MemoryBuffer.h" +#include +#include + +namespace { + +struct olGetEventElapsedTimeTest : OffloadQueueTest { + void SetUp() override { + RETURN_ON_FATAL_FAILURE(OffloadQueueTest::SetUp()); + + ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin)); + ASSERT_SUCCESS(olCreateProgram(Device, DeviceBin->getBufferStart(), + DeviceBin->getBufferSize(), &Program)); + ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &Kernel)); + + LaunchArgs.Dimensions = 1; + LaunchArgs.GroupSize = {64, 1, 1}; + LaunchArgs.NumGroups = {1, 1, 1}; + LaunchArgs.DynSharedMemory = 0; + + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, + LaunchArgs.GroupSize.x * sizeof(uint32_t), &Mem)); + } + + void TearDown() override { + if (Mem) + ASSERT_SUCCESS(olMemFree(Mem)); + if (Program) + ASSERT_SUCCESS(olDestroyProgram(Program)); + RETURN_ON_FATAL_FAILURE(OffloadQueueTest::TearDown()); + } + + void launchFoo() { + struct { + void *Mem; + } Args{Mem}; + + ASSERT_SUCCESS(olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args), + &LaunchArgs)); + } + + std::unique_ptr DeviceBin; + ol_program_handle_t Program = nullptr; + ol_symbol_handle_t Kernel = nullptr; + ol_kernel_launch_size_args_t LaunchArgs{}; + void *Mem = nullptr; +}; + +OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetEventElapsedTimeTest); + +TEST_P(olGetEventElapsedTimeTest, Success) { + ol_event_handle_t StartEvent = nullptr; + ol_event_handle_t EndEvent = nullptr; + + 
ASSERT_SUCCESS(olCreateEvent(Queue, &StartEvent)); + ASSERT_NE(StartEvent, nullptr); + + launchFoo(); + + ASSERT_SUCCESS(olCreateEvent(Queue, &EndEvent)); + ASSERT_NE(EndEvent, nullptr); + + ASSERT_SUCCESS(olSyncEvent(EndEvent)); + + float ElapsedTime = -1.0f; + + ASSERT_SUCCESS(olGetEventElapsedTime(StartEvent, EndEvent, &ElapsedTime)); + ASSERT_GE(ElapsedTime, 0.0f); + + ASSERT_SUCCESS(olDestroyEvent(StartEvent)); + ASSERT_SUCCESS(olDestroyEvent(EndEvent)); +} + +TEST_P(olGetEventElapsedTimeTest, SuccessMultipleCalls) { + ol_event_handle_t StartEvent = nullptr; + ol_event_handle_t EndEvent = nullptr; + + ASSERT_SUCCESS(olCreateEvent(Queue, &StartEvent)); + ASSERT_NE(StartEvent, nullptr); + + launchFoo(); + + ASSERT_SUCCESS(olCreateEvent(Queue, &EndEvent)); + ASSERT_NE(EndEvent, nullptr); + + ASSERT_SUCCESS(olSyncEvent(EndEvent)); + + float ElapsedTimeA = -1.0f; + float ElapsedTimeB = -1.0f; + + ASSERT_SUCCESS(olGetEventElapsedTime(StartEvent, EndEvent, &ElapsedTimeA)); + ASSERT_SUCCESS(olGetEventElapsedTime(StartEvent, EndEvent, &ElapsedTimeB)); + + ASSERT_GE(ElapsedTimeA, 0.0f); + ASSERT_GE(ElapsedTimeB, 0.0f); + + ASSERT_SUCCESS(olDestroyEvent(StartEvent)); + ASSERT_SUCCESS(olDestroyEvent(EndEvent)); +} + +TEST_P(olGetEventElapsedTimeTest, InvalidNullStartEvent) { + ol_event_handle_t EndEvent = nullptr; + ASSERT_SUCCESS(olCreateEvent(Queue, &EndEvent)); + + float ElapsedTime = 0.0f; + ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, + olGetEventElapsedTime(nullptr, EndEvent, &ElapsedTime)); + + ASSERT_SUCCESS(olDestroyEvent(EndEvent)); +} + +TEST_P(olGetEventElapsedTimeTest, InvalidNullEndEvent) { + ol_event_handle_t StartEvent = nullptr; + ASSERT_SUCCESS(olCreateEvent(Queue, &StartEvent)); + + float ElapsedTime = 0.0f; + ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, + olGetEventElapsedTime(StartEvent, nullptr, &ElapsedTime)); + + ASSERT_SUCCESS(olDestroyEvent(StartEvent)); +} + +TEST_P(olGetEventElapsedTimeTest, InvalidNullElapsedTime) { + ol_event_handle_t StartEvent = 
nullptr; + ol_event_handle_t EndEvent = nullptr; + + ASSERT_SUCCESS(olCreateEvent(Queue, &StartEvent)); + ASSERT_SUCCESS(olCreateEvent(Queue, &EndEvent)); + + ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, + olGetEventElapsedTime(StartEvent, EndEvent, nullptr)); + + ASSERT_SUCCESS(olDestroyEvent(StartEvent)); + ASSERT_SUCCESS(olDestroyEvent(EndEvent)); +} + +} // namespace