| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- # mypy: allow-untyped-defs
- import torch
- class Event:
- r"""Wrapper around an MPS event.
- MPS events are synchronization markers that can be used to monitor the
- device's progress, to accurately measure timing, and to synchronize MPS streams.
- Args:
- enable_timing (bool, optional): indicates if the event should measure time
- (default: ``False``)
- """
- def __init__(self, enable_timing=False):
- self.__eventId = torch._C._mps_acquireEvent(enable_timing)
- def __del__(self):
- # checks if torch._C is already destroyed
- if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0:
- torch._C._mps_releaseEvent(self.__eventId)
- def record(self):
- r"""Records the event in the default stream."""
- torch._C._mps_recordEvent(self.__eventId)
- def wait(self):
- r"""Makes all future work submitted to the default stream wait for this event."""
- torch._C._mps_waitForEvent(self.__eventId)
- def query(self):
- r"""Returns True if all work currently captured by event has completed."""
- return torch._C._mps_queryEvent(self.__eventId)
- def synchronize(self):
- r"""Waits until the completion of all work currently captured in this event.
- This prevents the CPU thread from proceeding until the event completes.
- """
- torch._C._mps_synchronizeEvent(self.__eventId)
- def elapsed_time(self, end_event):
- r"""Returns the time elapsed in milliseconds after the event was
- recorded and before the end_event was recorded.
- """
- return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)
|