compile_time_profiler.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # mypy: disallow-untyped-defs
  2. import logging
  3. import os
  4. from datetime import datetime
  5. from socket import gethostname
  6. from typing import Any, Optional
  7. from torch._strobelight.cli_function_profiler import StrobelightCLIFunctionProfiler
  8. logger = logging.getLogger("strobelight_compile_time_profiler")
  9. console_handler = logging.StreamHandler()
  10. formatter = logging.Formatter(
  11. "%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s"
  12. )
  13. console_handler.setFormatter(formatter)
  14. logger.addHandler(console_handler)
  15. logger.setLevel(logging.INFO)
  16. logger.propagate = False
  17. class StrobelightCompileTimeProfiler:
  18. success_profile_count: int = 0
  19. failed_profile_count: int = 0
  20. ignored_profile_runs: int = 0
  21. inside_profile_compile_time: bool = False
  22. enabled: bool = False
  23. # A unique identifier that is used as the run_user_name in the strobelight profile to
  24. # associate all compile time profiles together.
  25. identifier: Optional[str] = None
  26. current_phase: Optional[str] = None
  27. profiler: Optional[Any] = None
  28. max_stack_length: int = int(
  29. os.environ.get("COMPILE_STROBELIGHT_MAX_STACK_LENGTH", 127)
  30. )
  31. max_profile_time: int = int(
  32. os.environ.get("COMPILE_STROBELIGHT_MAX_PROFILE_TIME", 60 * 30)
  33. )
  34. # Collect sample each x cycles.
  35. sample_each: int = int(
  36. float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7))
  37. )
  38. @classmethod
  39. def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None:
  40. if cls.enabled:
  41. logger.info("compile time strobelight profiling already enabled")
  42. return
  43. logger.info("compile time strobelight profiling enabled")
  44. if profiler_class is StrobelightCLIFunctionProfiler:
  45. import shutil
  46. if not shutil.which("strobeclient"):
  47. logger.info(
  48. "strobeclient not found, cant enable compile time strobelight profiling, seems"
  49. "like you are not on a FB machine."
  50. )
  51. return
  52. cls.enabled = True
  53. cls._cls_init()
  54. # profiler_class should have public API similar to that of StrobelightCLIFunctionProfiler.
  55. # we have pass different functionProfilerClass for meta-internal fbcode targets.
  56. cls.profiler = profiler_class(
  57. sample_each=cls.sample_each,
  58. max_profile_duration_sec=cls.max_profile_time,
  59. stack_max_len=cls.max_stack_length,
  60. async_stack_max_len=cls.max_stack_length,
  61. run_user_name="pt2-profiler/"
  62. + os.environ.get("USER", os.environ.get("USERNAME", "")),
  63. sample_tags={cls.identifier},
  64. )
  65. @classmethod
  66. def _cls_init(cls) -> None:
  67. cls.identifier = "{date}{pid}{hostname}".format(
  68. date=datetime.now().strftime("%Y-%m-%d-%H:%M:%S"),
  69. pid=os.getpid(),
  70. hostname=gethostname(),
  71. )
  72. logger.info("Unique sample tag for this run is: %s", cls.identifier)
  73. logger.info(
  74. "You can use the following link to access the strobelight profile at the end of the run: %s",
  75. (
  76. "https://www.internalfb.com/intern/scuba/query/?dataset=pyperf_experime"
  77. "ntal%2Fon_demand&drillstate=%7B%22purposes%22%3A[]%2C%22end%22%3A%22no"
  78. "w%22%2C%22start%22%3A%22-30%20days%22%2C%22filterMode%22%3A%22DEFAULT%"
  79. "22%2C%22modifiers%22%3A[]%2C%22sampleCols%22%3A[]%2C%22cols%22%3A[%22n"
  80. "amespace_id%22%2C%22namespace_process_id%22]%2C%22derivedCols%22%3A[]%"
  81. "2C%22mappedCols%22%3A[]%2C%22enumCols%22%3A[]%2C%22return_remainder%22"
  82. "%3Afalse%2C%22should_pivot%22%3Afalse%2C%22is_timeseries%22%3Afalse%2C"
  83. "%22hideEmptyColumns%22%3Afalse%2C%22timezone%22%3A%22America%2FLos_Ang"
  84. "eles%22%2C%22compare%22%3A%22none%22%2C%22samplingRatio%22%3A%221%22%2"
  85. "C%22metric%22%3A%22count%22%2C%22aggregation_field%22%3A%22async_stack"
  86. "_complete%22%2C%22top%22%3A10000%2C%22aggregateList%22%3A[]%2C%22param"
  87. "_dimensions%22%3A[%7B%22dim%22%3A%22py_async_stack%22%2C%22op%22%3A%22"
  88. "edge%22%2C%22param%22%3A%220%22%2C%22anchor%22%3A%220%22%7D]%2C%22orde"
  89. "r%22%3A%22weight%22%2C%22order_desc%22%3Atrue%2C%22constraints%22%3A[["
  90. "%7B%22column%22%3A%22sample_tags%22%2C%22op%22%3A%22all%22%2C%22value%"
  91. f"22%3A[%22[%5C%22{cls.identifier}%5C%22]%22]%7D]]%2C%22c_constraints%22%3A[[]]%2C%22b"
  92. "_constraints%22%3A[[]]%2C%22ignoreGroupByInComparison%22%3Afalse%7D&vi"
  93. "ew=GraphProfilerView&&normalized=1712358002&pool=uber"
  94. ),
  95. )
  96. @classmethod
  97. def _log_stats(cls) -> None:
  98. logger.info(
  99. "%s strobelight success runs out of %s non-recursive compilation events.",
  100. cls.success_profile_count,
  101. cls.success_profile_count + cls.failed_profile_count,
  102. )
  103. # TODO use threadlevel meta data to tags to record phases.
  104. @classmethod
  105. def profile_compile_time(
  106. cls, func: Any, phase_name: str, *args: Any, **kwargs: Any
  107. ) -> Any:
  108. if not cls.enabled:
  109. return func(*args, **kwargs)
  110. if cls.profiler is None:
  111. logger.error("profiler is not set")
  112. return
  113. if cls.inside_profile_compile_time:
  114. cls.ignored_profile_runs += 1
  115. logger.info(
  116. "profile_compile_time is requested for phase: %s while already in running phase: %s, recursive call ignored",
  117. phase_name,
  118. cls.current_phase,
  119. )
  120. return func(*args, **kwargs)
  121. cls.inside_profile_compile_time = True
  122. cls.current_phase = phase_name
  123. work_result = cls.profiler.profile(func, *args, **kwargs)
  124. if cls.profiler.profile_result is not None:
  125. cls.success_profile_count += 1
  126. else:
  127. cls.failed_profile_count += 1
  128. cls._log_stats()
  129. cls.inside_profile_compile_time = False
  130. return work_result