# mypy: allow-untyped-defs
import functools
import operator
import re
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, TYPE_CHECKING

from torch.autograd.profiler import profile
from torch.profiler import DeviceType


if TYPE_CHECKING:
    from torch.autograd import _KinetoEvent


def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = False):
    order = reversed if reverse else lambda x: x
    remaining = deque(order(tree))
    while remaining:
        curr_event = next_fn(remaining)
        yield curr_event
        for child_event in order(children_fn(curr_event)):
            remaining.append(child_event)


traverse_dfs = functools.partial(_traverse, next_fn=lambda x: x.pop(), reverse=True)
traverse_bfs = functools.partial(
    _traverse, next_fn=lambda x: x.popleft(), reverse=False
)
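
# Example (illustrative only, not part of the module): for a stub node type
# with a `children` list and a tree root -> [a, b], a -> [c]:
#
#     traverse_dfs([root])  # yields root, a, c, b (children before siblings)
#     traverse_bfs([root])  # yields root, a, b, c (level order)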


@dataclass
class EventMetrics:
    duration_time_ns: int = 0
    self_time_ns: int = 0
    idle_time_ns: int = 0
    queue_depth: int = 0

    @property
    def fraction_idle_time(self):
        if self.duration_time_ns == 0:
            return 0.0
        return self.idle_time_ns / self.duration_time_ns


@dataclass
class Interval:
    start: int
    end: int
    queue_depth: int = 0


class EventKey:
    def __init__(self, event):
        self.event = event

    def __hash__(self):
        return hash(self.event.id)

    def __eq__(self, other):
        return self.event.id == other.event.id

    def __repr__(self):
        return f"{self.event.name}"

    def intervals_overlap(self, intervals: List[Interval]):
        overlap_time = 0
        intervals = sorted(intervals, key=lambda x: x.start)

        if intervals:
            overlap_start = max(self.event.start_time_ns, intervals[0].start)
            overlap_end = min(self.event.end_time_ns, intervals[0].end)

            if overlap_start < overlap_end:
                overlap_time += overlap_end - overlap_start

        i, j = 0, 1
        while j < len(intervals):
            prev_interval = intervals[i]
            curr_interval = intervals[j]
            j += 1
            if prev_interval.end > curr_interval.start:
                # Completely subsumed by previous interval
                if prev_interval.end > curr_interval.end:
                    j += 1
                    continue
                else:
                    curr_interval.start = prev_interval.end
                    i = j

            overlap_start = max(self.event.start_time_ns, curr_interval.start)
            overlap_end = min(self.event.end_time_ns, curr_interval.end)
            if overlap_start < overlap_end:
                overlap_time += overlap_end - overlap_start
        return overlap_time
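
# Worked example for intervals_overlap (illustrative values): for an event
# spanning [5, 25] and intervals [0, 10] and [8, 20], the first interval
# contributes min(25, 10) - max(5, 0) = 5. The second interval overlaps the
# first, so its start is trimmed to 10 and it contributes
# min(25, 20) - max(5, 10) = 10. The total is 15, which matches
# [5, 25] ∩ ([0, 10] ∪ [8, 20]) = [5, 20].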


class BasicEvaluation:
    def __init__(self, prof: profile):
        self.profile = prof
        self.metrics: Dict[EventKey, EventMetrics] = {}
        self.compute_self_time()
        self.event_keys = sorted(
            (e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns
        )
        self.events = [e.event for e in self.event_keys]
        self.cuda_events: List[_KinetoEvent] = []
        self.queue_depth_list = self.compute_queue_depth()
        self.compute_idle_time()

    def compute_self_time(self):
        """
        Computes each event's self time (total time minus time spent in child ops).
        """
        assert self.profile.kineto_results is not None
        stack = deque(self.profile.kineto_results.experimental_event_tree())

        # standard iterating dfs
        while stack:
            curr_event = stack.pop()
            self_time = curr_event.duration_time_ns
            for child_event in curr_event.children:
                self_time -= child_event.duration_time_ns
                stack.append(child_event)
            assert (
                EventKey(curr_event) not in self.metrics
            ), f"Duplicate id: {curr_event.id}, {curr_event.name}"
            self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time)
            self.metrics[
                EventKey(curr_event)
            ].duration_time_ns = curr_event.duration_time_ns

    def compute_queue_depth(self):
        """
        Computes queue_depth at each event. This will calculate the queue depth
        data for all the events in the tree.
        Returns a list of Interval objects with the queue depth data of CUDA
        launch and kernel events.
        """
        assert self.profile.kineto_results is not None
        cuda_event_list = self.profile.kineto_results.events()

        def is_cuda_launch_kernel(e):
            # TODO: find a better way to identify cudaLaunchKernel
            return e.name == "cudaLaunchKernel"

        def is_cuda_kernel(e):
            # TODO: find a better way to identify CUDA Kernel
            return e.device_type() == DeviceType.CUDA and "mem" not in e.name.lower()

        cuda_launch_events = sorted(
            (e for e in cuda_event_list if is_cuda_launch_kernel(e)),
            key=lambda x: x.start_ns(),
        )
        cuda_kernel_events = sorted(
            (e for e in cuda_event_list if is_cuda_kernel(e)),
            key=lambda x: x.start_ns(),
        )

        self.cuda_events = sorted(
            cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_ns()
        )

        kernel_mapping: Dict[_KinetoEvent, int] = {}
        last_mapped_kernel = 0
        for cuda_launch_event in cuda_launch_events:
            index = index_of_first_match(
                cuda_kernel_events,
                lambda x: x.linked_correlation_id()
                == cuda_launch_event.linked_correlation_id(),
                start=last_mapped_kernel,
            )
            kernel_mapping[cuda_launch_event] = index
            last_mapped_kernel = index if index is not None else last_mapped_kernel

        current_kernel_index = 0
        spawned_kernel_index = -1

        all_events = cuda_launch_events + cuda_kernel_events + self.events

        def new_old_event_comparator(event):
            if hasattr(event, "start_us"):
                return event.start_us() * 1000
            if hasattr(event, "start_ns"):
                return event.start_ns()
            if hasattr(event, "start_time_ns"):
                return event.start_time_ns
            raise Exception("Unknown Event Type")  # noqa: TRY002

        queue_depth_list: List[Interval] = []
        all_events.sort(key=new_old_event_comparator)
        for event in all_events:
            # Find latest cuda kernel event
            if hasattr(event, "start_us"):
                start_time = event.start_us() * 1000
                end_time = (event.start_us() + event.duration_us()) * 1000
                # Find current spawned cuda kernel event
                if event in kernel_mapping and kernel_mapping[event] is not None:
                    spawned_kernel_index = kernel_mapping[event]
            if hasattr(event, "start_ns"):
                start_time = event.start_ns()
                end_time = event.start_ns() + event.duration_ns()
                # Find current spawned cuda kernel event
                if event in kernel_mapping and kernel_mapping[event] is not None:
                    spawned_kernel_index = kernel_mapping[event]
            elif hasattr(event, "start_time_ns"):
                start_time = event.start_time_ns  # type: ignore[attr-defined]
                end_time = event.end_time_ns  # type: ignore[attr-defined]

            while (
                current_kernel_index < len(cuda_kernel_events)
                and (cuda_kernel_events[current_kernel_index].start_ns())
                <= start_time  # type: ignore[possibly-undefined]
            ):
                current_kernel_index += 1
            current_queue_depth = spawned_kernel_index - current_kernel_index + 1
            current_queue_depth = max(current_queue_depth, 0)

            if hasattr(event, "start_us") or hasattr(event, "start_ns"):
                queue_depth_list.append(
                    Interval(start_time, end_time, current_queue_depth)  # type: ignore[possibly-undefined]
                )
            elif hasattr(event, "start_time_ns"):
                self.metrics[EventKey(event)].queue_depth = current_queue_depth

        return queue_depth_list
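
    # Worked example of the depth arithmetic above (illustrative numbers): if
    # five kernels have been spawned so far (spawned_kernel_index = 4) and two
    # of them have already started running (current_kernel_index = 2), then
    # 4 - 2 + 1 = 3 kernels are enqueued but not yet running at this point.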

    def compute_idle_time(self):
        """
        Computes idle time of the profile.
        """
        # Based on queue_depth_list, we can calculate idle time for all the events
        idle = False
        idle_start = 0
        idle_intervals: List[Interval] = []
        if self.queue_depth_list and self.events:
            idle_intervals += [
                Interval(self.events[0].start_time_ns, self.queue_depth_list[0].start),
                Interval(self.queue_depth_list[-1].end, self.events[-1].end_time_ns),
            ]

        for data_point in self.queue_depth_list:
            if data_point.queue_depth == 0 and not idle:
                idle_start = data_point.end
                idle = True
            if data_point.queue_depth > 0 and idle:
                idle_intervals.append(Interval(idle_start, data_point.start))
                idle = False

        event_list = [e.event for e in self.metrics.keys()]
        for event in event_list:
            self.metrics[EventKey(event)].idle_time_ns = EventKey(
                event
            ).intervals_overlap(idle_intervals)
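
    # For example (illustrative numbers): if the queue depth drops to 0 in an
    # interval ending at t = 100 and only rises above 0 again in an interval
    # starting at t = 150, then [100, 150] is recorded as an idle interval.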

    def rank_events(self, length):
        """
        Filter and rank the events based on some heuristics:
        1) Events that are in the falling phase of the queue depth.
        2) Events with high idle time and self time relative to the other
           filtered events (combined into a heuristic score).

        Parameters:
            length: The number of events to return.
        """

        # Find the interval when qd is falling to 0
        import torch

        queue_depth_list = list(reversed(self.queue_depth_list))
        qd_values = [e.queue_depth for e in queue_depth_list]

        bottom_threshold = 0
        top_threshold = 4
        decrease_interval = []
        i = 0
        while i < len(qd_values):
            if qd_values[i] > bottom_threshold:
                i += 1
                continue
            for j in range(i + 1, len(qd_values)):
                # Find the next zero; if the max value between the two zeros
                # exceeds the threshold, we have a falling interval
                next_minimum_idx = index_of_first_match(
                    qd_values, lambda x: x <= bottom_threshold, start=j
                )
                peak_idx = argmax(qd_values, start=j, end=next_minimum_idx)

                # If it is a valid peak, add the interval to the list and continue
                if peak_idx is not None and qd_values[peak_idx] >= top_threshold:
                    decrease_interval.append(
                        Interval(
                            queue_depth_list[peak_idx].start, queue_depth_list[i].start
                        )
                    )
                    i = next_minimum_idx if next_minimum_idx is not None else i
                    break
            i += 1

        # Filter out events that are not in the decrease interval
        event_list = [
            event
            for event in self.metrics.keys()
            if event.intervals_overlap(decrease_interval)
        ]
        if event_list:
            self_time = torch.tensor(
                [self.metrics[event].self_time_ns for event in event_list],
                dtype=torch.float32,
            )
            idle_time = torch.tensor(
                [self.metrics[event].fraction_idle_time for event in event_list],
                dtype=torch.float32,
            )
            normalized_gain = (idle_time - torch.mean(idle_time)) / torch.std(idle_time)
            normalized_self = (self_time - torch.mean(self_time)) / torch.std(self_time)
            heuristic_score_list = normalized_gain + 0.6 * normalized_self

            # Sort events by heuristic score
            event_list = [
                event
                for _, event in sorted(
                    zip(heuristic_score_list, event_list),
                    key=operator.itemgetter(0),
                    reverse=True,
                )
            ]
            event_list = event_list[:length]
        return event_list

    def get_optimizable_events(self, length: int = 1, print_enable: bool = True):
        event_list = self.rank_events(length)
        if not print_enable:
            return event_list
        output = "Optimizable events:\n" if event_list else "No events to optimize\n"

        output += "\n".join(
            [
                f"""{'-'*80}
Event:                {event}
Source code location: {source_code_location(event.event)}
Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}%
{'-'*80}"""
                for event in event_list
            ]
        )
        if print_enable:
            print(output)
        return event_list
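
# Usage sketch (illustrative; `model` and `inp` are placeholders, and the
# workload must launch CUDA kernels for the queue-depth analysis to be
# meaningful):
#
#     from torch.profiler import profile
#     from torch.profiler._utils import BasicEvaluation
#
#     with profile() as prof:
#         model(inp)
#     evaluation = BasicEvaluation(prof.profiler)
#     evaluation.get_optimizable_events(length=3)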


def index_of_first_match(seq, predicate, start=0, end=None):
    if end is None or end >= len(seq):
        end = len(seq)
    for i in range(start, end):
        if predicate(seq[i]):
            return i
    return None


def argmax(seq, key=lambda x: x, start=0, end=None):
    seq = seq[start:end]
    if len(seq) == 0:
        return None
    return seq.index(max(seq, key=key)) + start
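
# For example: index_of_first_match([1, 3, 5], lambda x: x > 2) returns 1, and
# argmax([3, 1, 4, 1, 5], start=1, end=4) searches the slice [1, 4, 1] and
# returns 2, the index of the maximum (4) in the original sequence.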


def source_code_location(event):
    while event is not None:
        match = re.search(r"\.py\(.*\)", event.name)
        if match is None:
            event = event.parent
            continue
        return event.name
    return "No source code location found"


# Provide an OSS workaround for cudagraphs + CUPTI issue
# https://github.com/pytorch/pytorch/issues/75504
# TODO(dberard) - deprecate / remove workaround for CUDA >= 12, when
# we stop supporting older CUDA versions.
def _init_for_cuda_graphs():
    from torch.autograd.profiler import profile

    with profile():
        pass