| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- // Copyright © 2023 Apple Inc.
- #pragma once
- #include <ATen/mps/MPSStream.h>
- #include <ctime>
- #include <stack>
- namespace at::mps {
- // NOTE: don't create instances of this class directly.
- // Use MPSEventPool to acquire instances of MPSEvent.
- class MPSEvent {
- public:
- explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
- ~MPSEvent();
- // records an event on the stream
- void record(bool needsLock, bool syncEvent = false);
- // makes all future work submitted to the stream wait for this event.
- bool wait(bool needsLock, bool syncEvent = false);
- // schedules a notifyListener callback for the event.
- bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
- // checks if events are already signaled.
- bool query() const;
- // blocks the CPU thread until all the GPU work that were scheduled
- // prior to recording this event are completed.
- bool synchronize();
- // resets this event with new parameters in case it gets reused from the event pool
- void reset(MPSStream* stream, bool enable_timing);
- // returns the unique ID of the event instance
- id_t getID() const { return m_id; }
- // returns the completion timestamp of the event
- uint64_t getCompletionTime() const { return m_completion_time; }
- // if already recorded, waits for cpu_sync_cv to be signaled
- void waitForCpuSync();
- private:
- id_t m_id;
- // enables measuring the completion time of the notifyListener of this event
- bool m_enable_timing;
- uint64_t m_signalCounter = 0;
- MPSStream* m_stream = nullptr;
- MTLSharedEvent_t m_event = nullptr;
- MTLSharedEventListener* m_listener = nullptr;
- // used to sync the events created on this Stream with CPU
- std::mutex m_cpu_sync_mutex{};
- std::condition_variable m_cpu_sync_cv{};
- // CondVar predicate to sync the events created on this Stream with CPU
- bool m_cpu_sync_completed = false;
- // used to compute elapsed time
- uint64_t m_completion_time = 0;
- void recordLocked(bool syncEvent);
- bool waitLocked(bool syncEvent);
- bool notifyLocked(MTLSharedEventNotificationBlock block);
- void notifyCpuSync();
- static uint64_t getTime() {
- return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
- }
- };
- typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
- class MPSEventPool {
- public:
- explicit MPSEventPool(MPSStream* default_stream);
- ~MPSEventPool();
- MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
- void emptyCache();
- // these are mainly used for MPSHooks and torch.mps.Event() bindings
- id_t acquireEvent(bool enable_timing);
- void releaseEvent(id_t event_id);
- void recordEvent(id_t event_id, bool syncEvent);
- void waitForEvent(id_t event_id, bool syncEvent);
- void synchronizeEvent(id_t event_id);
- bool queryEvent(id_t event_id);
- // returns elapsed time between two recorded events in milliseconds
- double elapsedTime(id_t start_event_id, id_t end_event_id);
- private:
- MPSStream* m_default_stream = nullptr;
- std::recursive_mutex m_mutex;
- std::stack<std::unique_ptr<MPSEvent>> m_pool{};
- // dictionary to associate event IDs with event objects
- // used to retain in-use events out of the pool
- // for torch.mps.Event() bindings.
- std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
- uint64_t m_event_counter = 0;
- std::function<void(MPSEvent*)> m_default_deleter;
- MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
- };
- // shared_ptr is used to get MPSEventPool destroyed after dependent instances
- std::shared_ptr<MPSEventPool> getMPSEventPool();
- } // namespace at::mps
|