event.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. # mypy: allow-untyped-defs
  2. import torch
  3. class Event:
  4. r"""Wrapper around an MPS event.
  5. MPS events are synchronization markers that can be used to monitor the
  6. device's progress, to accurately measure timing, and to synchronize MPS streams.
  7. Args:
  8. enable_timing (bool, optional): indicates if the event should measure time
  9. (default: ``False``)
  10. """
  11. def __init__(self, enable_timing=False):
  12. self.__eventId = torch._C._mps_acquireEvent(enable_timing)
  13. def __del__(self):
  14. # checks if torch._C is already destroyed
  15. if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0:
  16. torch._C._mps_releaseEvent(self.__eventId)
  17. def record(self):
  18. r"""Records the event in the default stream."""
  19. torch._C._mps_recordEvent(self.__eventId)
  20. def wait(self):
  21. r"""Makes all future work submitted to the default stream wait for this event."""
  22. torch._C._mps_waitForEvent(self.__eventId)
  23. def query(self):
  24. r"""Returns True if all work currently captured by event has completed."""
  25. return torch._C._mps_queryEvent(self.__eventId)
  26. def synchronize(self):
  27. r"""Waits until the completion of all work currently captured in this event.
  28. This prevents the CPU thread from proceeding until the event completes.
  29. """
  30. torch._C._mps_synchronizeEvent(self.__eventId)
  31. def elapsed_time(self, end_event):
  32. r"""Returns the time elapsed in milliseconds after the event was
  33. recorded and before the end_event was recorded.
  34. """
  35. return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)