_gpu_trace.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from typing import Callable
  2. from torch._utils import CallbackRegistry
  3. EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  4. "CUDA event creation"
  5. )
  6. EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  7. "CUDA event deletion"
  8. )
  9. EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
  10. "CUDA event record"
  11. )
  12. EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
  13. "CUDA event wait"
  14. )
  15. MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  16. "CUDA memory allocation"
  17. )
  18. MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  19. "CUDA memory deallocation"
  20. )
  21. StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  22. "CUDA stream creation"
  23. )
  24. DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
  25. "CUDA device synchronization"
  26. )
  27. StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  28. "CUDA stream synchronization"
  29. )
  30. EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  31. "CUDA event synchronization"
  32. )
  33. def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
  34. EventCreationCallbacks.add_callback(cb)
  35. def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
  36. EventDeletionCallbacks.add_callback(cb)
  37. def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
  38. EventRecordCallbacks.add_callback(cb)
  39. def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
  40. EventWaitCallbacks.add_callback(cb)
  41. def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
  42. MemoryAllocationCallbacks.add_callback(cb)
  43. def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
  44. MemoryDeallocationCallbacks.add_callback(cb)
  45. def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
  46. StreamCreationCallbacks.add_callback(cb)
  47. def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
  48. DeviceSynchronizationCallbacks.add_callback(cb)
  49. def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
  50. StreamSynchronizationCallbacks.add_callback(cb)
  51. def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
  52. EventSynchronizationCallbacks.add_callback(cb)