# mypy: allow-untyped-defs
import gzip
import json
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing_extensions import Self
from warnings import warn

import torch
import torch.autograd.profiler as prof
from torch._C import _get_privateuse1_backend_name
from torch._C._profiler import (
    _add_execution_trace_observer,
    _disable_execution_trace_observer,
    _enable_execution_trace_observer,
    _ExperimentalConfig,
    _remove_execution_trace_observer,
)
from torch.autograd import kineto_available, ProfilerActivity
from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline


__all__ = [
    "supported_activities",
    "ProfilerAction",
    "schedule",
    "tensorboard_trace_handler",
    "profile",
    "ExecutionTraceObserver",
]

PROFILER_STEP_NAME = "ProfilerStep"


def supported_activities():
    """
    Returns a set of supported profiler tracing activities.

    Note: profiler uses the CUPTI library to trace on-device CUDA kernels.
    When CUDA is enabled but CUPTI is not available, passing
    ``ProfilerActivity.CUDA`` to the profiler results in using the legacy CUDA
    profiling code (the same as in the legacy ``torch.autograd.profiler``).
    This, in turn, results in CUDA time being included in the profiler table output,
    but not in the JSON trace.
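
    Example (a minimal sketch; the returned set depends on the build and hardware):

    .. code-block:: python

        import torch.profiler

        # CUDA appears in this set only when CUDA (and CUPTI) support is present.
        if torch.profiler.ProfilerActivity.CUDA in torch.profiler.supported_activities():
            print("CUDA activity tracing is available")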
  43. """
  44. return torch.autograd._supported_activities()


class _ITraceObserver(ABC):
    """Abstract interface for a trace observer.
    It defines three methods: start, stop, and cleanup."""

    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def stop(self):
        pass

    @abstractmethod
    def cleanup(self):
        pass


class _KinetoProfile:
    """Low-level profiler that wraps the autograd profiler.

    Args:
        activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
            ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
            ``torch.profiler.ProfilerActivity.XPU``.
            Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
            or (when available) ProfilerActivity.XPU.
        record_shapes (bool): save information about operator's input shapes.
        profile_memory (bool): track tensor memory allocation/deallocation (see ``export_memory_timeline``
            for more details).
        with_stack (bool): record source information (file and line number) for the ops.
        with_flops (bool): use formula to estimate the FLOPS of specific operators
            (matrix multiplication and 2D convolution).
        with_modules (bool): record module hierarchy (including function names)
            corresponding to the callstack of the op. E.g. if module A's forward calls
            module B's forward, which contains an aten::add op,
            then aten::add's module hierarchy is A.B.
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.
        experimental_config (_ExperimentalConfig): A set of experimental options
            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
        execution_trace_observer (ExecutionTraceObserver): A PyTorch Execution Trace Observer object.
            `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph-based
            representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
            When this argument is included, the observer's start() and stop() will be called for the
            same time window as the PyTorch profiler.

    .. note::
        This API is experimental and subject to change in the future.

        Enabling shape and stack tracing results in additional overhead.
        When record_shapes=True is specified, the profiler will temporarily hold references to the tensors;
        that may further prevent certain optimizations that depend on the reference count and introduce
        extra tensor copies.
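
    Example (a minimal sketch of direct use; ``run_workload`` is a hypothetical
    user function, and most users should prefer :class:`torch.profiler.profile`,
    which adds scheduling on top of this class):

    .. code-block:: python

        from torch.profiler import ProfilerActivity
        from torch.profiler.profiler import _KinetoProfile

        p = _KinetoProfile(activities=[ProfilerActivity.CPU], record_shapes=True)
        p.start()
        run_workload()  # hypothetical user function
        p.stop()
        p.export_chrome_trace("trace.json")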
  90. """
  91. def __init__(
  92. self,
  93. *,
  94. activities: Optional[Iterable[ProfilerActivity]] = None,
  95. record_shapes: bool = False,
  96. profile_memory: bool = False,
  97. with_stack: bool = False,
  98. with_flops: bool = False,
  99. with_modules: bool = False,
  100. experimental_config: Optional[_ExperimentalConfig] = None,
  101. execution_trace_observer: Optional[_ITraceObserver] = None,
  102. ):
  103. self.activities = set(activities) if activities else supported_activities()
  104. self.record_shapes = record_shapes
  105. self.with_flops = with_flops
  106. self.profile_memory = profile_memory
  107. self.with_stack = with_stack
  108. self.with_modules = with_modules
  109. self.experimental_config = experimental_config
  110. self.execution_trace_observer = execution_trace_observer
  111. self.profiler: Optional[prof.profile] = None
  112. self.mem_tl: Optional[MemoryProfileTimeline] = None
  113. self.use_device = None
  114. if ProfilerActivity.CUDA in self.activities:
  115. self.use_device = "cuda"
  116. elif ProfilerActivity.XPU in self.activities:
  117. self.use_device = "xpu"
  118. elif ProfilerActivity.PrivateUse1 in self.activities:
  119. self.use_device = _get_privateuse1_backend_name()
  120. # user-defined metadata to be amended to the trace
  121. self.preset_metadata: Dict[str, str] = dict()
    def start(self):
        self.prepare_trace()
        self.start_trace()

    def stop(self):
        self.stop_trace()

    def prepare_trace(self):
        if self.profiler is None:
            self.profiler = prof.profile(
                use_cpu=(ProfilerActivity.CPU in self.activities),
                use_mtia=(ProfilerActivity.MTIA in self.activities),
                use_device=self.use_device,
                record_shapes=self.record_shapes,
                with_flops=self.with_flops,
                profile_memory=self.profile_memory,
                with_stack=self.with_stack,
                with_modules=self.with_modules,
                use_kineto=True,
                experimental_config=self.experimental_config,
            )
        self.profiler._prepare_trace()

    def start_trace(self):
        if self.execution_trace_observer:
            self.execution_trace_observer.start()
        assert self.profiler is not None
        self.profiler._start_trace()

        if self.profile_memory:
            self.add_metadata_json("profile_memory", "1")
        if self.with_stack:
            self.add_metadata_json("with_stack", "1")
        if self.record_shapes:
            self.add_metadata_json("record_shapes", "1")
        if self.with_modules:
            self.add_metadata_json("with_modules", "1")
        if self.with_flops:
            self.add_metadata_json("with_flops", "1")

        if kineto_available():
            dist_info = self._get_distributed_info()
            if dist_info:
                self.add_metadata_json("distributedInfo", json.dumps(dist_info))

            if hasattr(torch, "_inductor"):
                import torch._inductor.config as inductor_config

                if inductor_config.triton.cudagraphs:
                    os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
                    self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1")
                    # FIXME: CUDA Graph does not work well with CUPTI teardown.
                    #   1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
                    #   2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
                    # Workaround: turn off CUPTI teardown when using CUDA Graphs.
                    os.environ["TEARDOWN_CUPTI"] = "0"

        # Insert the preset user metadata to the trace
        for k, v in self.preset_metadata.items():
            self.add_metadata_json(k, v)

    def stop_trace(self):
        if self.execution_trace_observer:
            self.execution_trace_observer.stop()
        assert self.profiler is not None
        self.profiler.__exit__(None, None, None)
    def export_chrome_trace(self, path: str):
        """
        Exports the collected trace in Chrome JSON format. If kineto is enabled, only
        the last cycle in the schedule is exported.
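
        Example (a minimal sketch; a path ending in ``.gz`` makes the trace
        gzip-compressed on write):

        .. code-block:: python

            prof.export_chrome_trace("trace.json")     # plain Chrome JSON
            prof.export_chrome_trace("trace.json.gz")  # gzip-compressed JSON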
  183. """
  184. assert self.profiler
  185. if path.endswith(".gz"):
  186. fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
  187. fp.close()
  188. retvalue = self.profiler.export_chrome_trace(fp.name)
  189. with open(fp.name) as fin:
  190. with gzip.open(path, "wt") as fout:
  191. fout.writelines(fin)
  192. os.remove(fp.name)
  193. return retvalue
  194. else:
  195. return self.profiler.export_chrome_trace(path)
    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        """Save stack traces to a file.

        Args:
            path (str): save stacks file to this location;
            metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
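
        .. note::
            Stack export requires the profiler to be created with ``with_stack=True``.
            As one example (assuming the stacks were saved to ``profiler.stacks``),
            the file can be visualized with Brendan Gregg's FlameGraph tool:

            - git clone https://github.com/brendangregg/FlameGraph
            - cd FlameGraph
            - ./flamegraph.pl --title "CPU time" --countname "us." profiler.stacks > perf_viz.svg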
        """
        assert self.profiler
        return self.profiler.export_stacks(path, metric)

    def key_averages(
        self, group_by_input_shape: bool = False, group_by_stack_n: int = 0
    ):
        """Averages events, grouping them by operator name and (optionally) input shapes and
        stack.

        .. note::
            To use shape/stack functionality make sure to set record_shapes/with_stack
            when creating the profiler context manager.
        """
        assert self.profiler
        return self.profiler.key_averages(group_by_input_shape, group_by_stack_n)

    def events(self):
        """
        Returns the list of unaggregated profiler events,
        to be used in the trace callback or after the profiling is finished.
        """
        assert self.profiler
        return self.profiler.function_events
    def add_metadata(self, key: str, value: str):
        """
        Adds user-defined metadata with a string key and a string value
        into the trace file.
        """
        wrapped_value = '"' + value.replace('"', '\\"') + '"'
        torch.autograd._add_metadata_json(key, wrapped_value)

    def add_metadata_json(self, key: str, value: str):
        """
        Adds user-defined metadata with a string key and a valid JSON value
        into the trace file.
        """
        torch.autograd._add_metadata_json(key, value)

    def preset_metadata_json(self, key: str, value: str):
        """
        Presets user-defined metadata while the profiler is not yet started,
        to be added into the trace file later.
        Metadata is in the format of a string key and a valid JSON value.
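
        Example (a minimal sketch; the key and value here are arbitrary):

        .. code-block:: python

            profiler.preset_metadata_json("run_config", json.dumps({"lr": 0.01}))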
        """
        self.preset_metadata[key] = value
    def _get_distributed_info(self):
        import torch.distributed as dist

        if not dist.is_available() or not dist.is_initialized():
            return None

        backend = dist.get_backend()
        dist_info = {
            "backend": backend,
            "rank": dist.get_rank(),
            "world_size": dist.get_world_size(),
            "pg_count": dist.get_pg_count(),
            "pg_config": dist.distributed_c10d._get_all_pg_configs(),
        }
        if backend == "nccl":
            nccl_version = torch.cuda.nccl.version()
            dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
        return dist_info

    def _memory_profile(self) -> MemoryProfile:
        required = ("record_shapes", "profile_memory", "with_stack")
        missing = [f"{i}=True" for i in required if not getattr(self, i)]
        if missing:
            raise ValueError(f"{', '.join(missing)} required for memory profiling.")

        assert self.profiler is not None and self.profiler.kineto_results is not None
        return MemoryProfile(self.profiler.kineto_results)
    def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
        """Export memory event information from the profiler-collected
        tree for a given device, and export a timeline plot. There are 3
        exportable files using ``export_memory_timeline``, each controlled by the
        ``path``'s suffix.

        - For an HTML-compatible plot, use the suffix ``.html``, and a memory timeline
          plot will be embedded as a PNG file in the HTML file.

        - For plot points consisting of ``[times, [sizes by category]]``, where
          ``times`` are timestamps and ``sizes`` are memory usage for each category,
          the memory timeline plot will be saved as JSON (``.json``) or gzipped JSON
          (``.json.gz``), depending on the suffix.

        - For raw memory points, use the suffix ``.raw.json.gz``. Each raw memory
          event will consist of ``(timestamp, action, numbytes, category)``, where
          ``action`` is one of ``[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]``,
          and ``category`` is one of the enums from
          ``torch.profiler._memory_profiler.Category``.

        Output: Memory timeline written as gzipped JSON, JSON, or HTML.
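
        Example (a minimal sketch; the suffix alone selects the output format):

        .. code-block:: python

            prof.export_memory_timeline("timeline.html", device="cuda:0")         # embedded PNG plot
            prof.export_memory_timeline("timeline.json.gz", device="cuda:0")      # gzipped plot points
            prof.export_memory_timeline("timeline.raw.json.gz", device="cuda:0")  # raw events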
        """
        # Default to device 0, if unset. Fallback on cpu.
        if device is None and self.use_device and self.use_device != "cuda":
            device = self.use_device + ":0"

        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # Construct the memory timeline plot data
        self.mem_tl = MemoryProfileTimeline(self._memory_profile())

        # Depending on the file suffix, save the data as json.gz or json.
        # For html, we can embed the image into an HTML file.
        if path.endswith(".html"):
            self.mem_tl.export_memory_timeline_html(path, device)
        elif path.endswith(".gz"):
            fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
            fp.close()
            if path.endswith("raw.json.gz"):
                self.mem_tl.export_memory_timeline_raw(fp.name, device)
            else:
                self.mem_tl.export_memory_timeline(fp.name, device)
            with open(fp.name) as fin:
                with gzip.open(path, "wt") as fout:
                    fout.writelines(fin)
            os.remove(fp.name)
        else:
            self.mem_tl.export_memory_timeline(path, device)


class ProfilerAction(Enum):
    """
    Profiler actions that can be taken at the specified intervals
    """

    NONE = 0
    WARMUP = 1
    RECORD = 2
    RECORD_AND_SAVE = 3


def schedule(
    *, wait: int, warmup: int, active: int, repeat: int = 0, skip_first: int = 0
) -> Callable:
    """
    Returns a callable that can be used as the profiler ``schedule`` argument. The profiler will skip
    the first ``skip_first`` steps, then wait for ``wait`` steps, then do the warmup for the next ``warmup`` steps,
    then do the active recording for the next ``active`` steps, and then repeat the cycle starting with ``wait`` steps.
    The optional number of cycles is specified with the ``repeat`` parameter; a value of zero means that
    the cycles will continue until the profiling is finished.
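
    Example (a minimal sketch of how the phases map onto step numbers):

    .. code-block:: python

        sched = torch.profiler.schedule(skip_first=1, wait=1, warmup=1, active=2, repeat=1)
        # step 0            -> NONE (skipped)
        # step 1            -> NONE (wait)
        # step 2            -> WARMUP
        # step 3            -> RECORD
        # step 4            -> RECORD_AND_SAVE (last active step of the cycle)
        # steps 5 and later -> NONE (repeat=1 stops after one cycle)
        actions = [sched(step) for step in range(6)]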
    """

    def schedule_fn(step: int) -> ProfilerAction:
        assert step >= 0
        if step < skip_first:
            return ProfilerAction.NONE
        else:
            step -= skip_first
        num_steps = wait + warmup + active
        if repeat > 0 and step / num_steps >= repeat:
            return ProfilerAction.NONE
        mod_step = step % num_steps
        if mod_step < wait:
            return ProfilerAction.NONE
        elif mod_step < wait + warmup:
            return ProfilerAction.WARMUP
        else:
            return (
                ProfilerAction.RECORD
                if mod_step < num_steps - 1
                else ProfilerAction.RECORD_AND_SAVE
            )

    assert (
        wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
    ), "Invalid profiler schedule arguments"
    if warmup == 0:
        warn("Profiler won't be using warmup, this can skew profiler results")
    return schedule_fn


def _default_schedule_fn(_: int) -> ProfilerAction:
    """
    Default profiler behavior - immediately starts recording the events,
    keeps doing it on every profiler step.
    """
    return ProfilerAction.RECORD


def tensorboard_trace_handler(
    dir_name: str, worker_name: Optional[str] = None, use_gzip: bool = False
):
    """
    Outputs tracing files to the ``dir_name`` directory; that directory can then be
    passed directly to TensorBoard as a logdir.
    ``worker_name`` should be unique for each worker in a distributed scenario;
    it defaults to '[hostname]_[pid]'.
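
    Example (a minimal sketch; ``./log`` is an arbitrary output directory and
    ``code_to_profile`` a hypothetical user function):

    .. code-block:: python

        with torch.profiler.profile(
            on_trace_ready=torch.profiler.tensorboard_trace_handler("./log")
        ) as p:
            code_to_profile()
        # then inspect with: tensorboard --logdir ./log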
    """
    import os
    import socket
    import time

    def handler_fn(prof) -> None:
        nonlocal worker_name
        if not os.path.isdir(dir_name):
            try:
                os.makedirs(dir_name, exist_ok=True)
            except Exception as e:
                raise RuntimeError("Can't create directory: " + dir_name) from e
        if not worker_name:
            worker_name = f"{socket.gethostname()}_{os.getpid()}"
        # Use nanosecond here to avoid naming clash when exporting the trace
        file_name = f"{worker_name}.{time.time_ns()}.pt.trace.json"
        if use_gzip:
            file_name = file_name + ".gz"
        prof.export_chrome_trace(os.path.join(dir_name, file_name))

    return handler_fn


class profile(_KinetoProfile):
    """Profiler context manager.

    Args:
        activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
            ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
            ``torch.profiler.ProfilerActivity.XPU``.
            Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
            or (when available) ProfilerActivity.XPU.
        schedule (Callable): callable that takes step (int) as a single parameter and returns
            the ``ProfilerAction`` value that specifies the profiler action to perform at each step.
        on_trace_ready (Callable): callable that is called at each step when ``schedule``
            returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
        record_shapes (bool): save information about operator's input shapes.
        profile_memory (bool): track tensor memory allocation/deallocation.
        with_stack (bool): record source information (file and line number) for the ops.
        with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators
            (matrix multiplication and 2D convolution).
        with_modules (bool): record module hierarchy (including function names)
            corresponding to the callstack of the op. E.g. if module A's forward calls
            module B's forward, which contains an aten::add op,
            then aten::add's module hierarchy is A.B.
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.
        experimental_config (_ExperimentalConfig): A set of experimental options
            used for Kineto library features. Note, backward compatibility is not guaranteed.
        execution_trace_observer (ExecutionTraceObserver): A PyTorch Execution Trace Observer object.
            `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph-based
            representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
            When this argument is included, the observer's start() and stop() will be called for the
            same time window as the PyTorch profiler. See the examples section below for a code sample.
        use_cuda (bool):
            .. deprecated:: 1.8.1
                use ``activities`` instead.

    .. note::
        Use :func:`~torch.profiler.schedule` to generate the callable schedule.
        Non-default schedules are useful when profiling long training jobs
        and allow the user to obtain multiple traces at the different iterations
        of the training process.
        The default schedule simply records all the events continuously for the
        duration of the context manager.

    .. note::
        Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:

        ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``

        After profiling, result files can be found in the specified directory. Use the command:

        ``tensorboard --logdir dir_name``

        to see the results in TensorBoard.
        For more information, see
        `PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__

    .. note::
        Enabling shape and stack tracing results in additional overhead.
        When record_shapes=True is specified, the profiler will temporarily hold references to the tensors;
        that may further prevent certain optimizations that depend on the reference count and introduce
        extra tensor copies.
    Examples:

    .. code-block:: python

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        ) as p:
            code_to_profile()
        print(p.key_averages().table(
            sort_by="self_cuda_time_total", row_limit=-1))

    Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:

    .. code-block:: python

        # Non-default profiler schedule allows user to turn profiler on and off
        # on different iterations of the training loop;
        # trace_handler is called every time a new trace becomes available
        def trace_handler(prof):
            print(prof.key_averages().table(
                sort_by="self_cuda_time_total", row_limit=-1))
            # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            # In this example with wait=1, warmup=1, active=2, repeat=1,
            # profiler will skip the first step/iteration,
            # start warming up on the second, record
            # the third and the fourth iterations,
            # after which the trace will become available
            # and on_trace_ready (when set) is called;
            # the cycle repeats starting with the next step
            schedule=torch.profiler.schedule(
                wait=1,
                warmup=1,
                active=2,
                repeat=1),
            on_trace_ready=trace_handler
            # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
            # used when outputting for tensorboard
        ) as p:
            for iter in range(N):
                code_iteration_to_profile(iter)
                # send a signal to the profiler that the next iteration has started
                p.step()

    The following sample shows how to set up an Execution Trace Observer (``execution_trace_observer``):

    .. code-block:: python

        with torch.profiler.profile(
            ...
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback("./execution_trace.json")
            ),
        ) as p:
            for iter in range(N):
                code_iteration_to_profile(iter)
                p.step()

    You can also refer to test_execution_trace_with_kineto() in tests/profiler/test_profiler.py.
    Note: one can also pass any object satisfying the _ITraceObserver interface.
    """
    def __init__(
        self,
        *,
        activities: Optional[Iterable[ProfilerActivity]] = None,
        schedule: Optional[Callable[[int], ProfilerAction]] = None,
        on_trace_ready: Optional[Callable[..., Any]] = None,
        record_shapes: bool = False,
        profile_memory: bool = False,
        with_stack: bool = False,
        with_flops: bool = False,
        with_modules: bool = False,
        experimental_config: Optional[_ExperimentalConfig] = None,
        execution_trace_observer: Optional[_ITraceObserver] = None,
        # deprecated:
        use_cuda: Optional[bool] = None,
    ):
        activities_set = set(activities) if activities else supported_activities()
        if use_cuda is not None:
            warn(
                "`use_cuda` is deprecated, use `activities` argument instead",
                FutureWarning,
                stacklevel=2,
            )
            if use_cuda:
                activities_set.add(ProfilerActivity.CUDA)
            elif ProfilerActivity.CUDA in activities_set:
                activities_set.remove(ProfilerActivity.CUDA)
        assert len(activities_set) > 0, "No valid profiler activities found"

        super().__init__(
            activities=activities,
            record_shapes=record_shapes,
            profile_memory=profile_memory,
            with_stack=with_stack,
            with_flops=with_flops,
            with_modules=with_modules,
            experimental_config=experimental_config,
            execution_trace_observer=execution_trace_observer,
        )

        if schedule:
            self.schedule = schedule
            # add step markers into the trace and table view
            self.record_steps = True
        else:
            self.schedule = _default_schedule_fn
            self.record_steps = False
        self.on_trace_ready = on_trace_ready
        self.step_num = 0
        self.current_action = self.schedule(self.step_num)
        self.step_rec_fn: Optional[prof.record_function] = None

        self.action_map: Dict[
            Tuple[ProfilerAction, Optional[ProfilerAction]], List[Any]
        ] = {
            # key is (prev_action, current_action), value is action list corresponding to the state pair.
            (ProfilerAction.NONE, ProfilerAction.NONE): [],
            (ProfilerAction.NONE, ProfilerAction.WARMUP): [self.prepare_trace],
            (ProfilerAction.NONE, ProfilerAction.RECORD): [
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.WARMUP, ProfilerAction.NONE): [
                partial(warn, "Incorrect schedule: WARMUP followed by NONE"),
                self.start_trace,
                self.stop_trace,
            ],
            (ProfilerAction.WARMUP, ProfilerAction.WARMUP): [],
            (ProfilerAction.WARMUP, ProfilerAction.RECORD): [self.start_trace],
            (ProfilerAction.WARMUP, ProfilerAction.RECORD_AND_SAVE): [self.start_trace],
            (ProfilerAction.RECORD, ProfilerAction.NONE): [
                partial(warn, "Incorrect schedule: RECORD followed by NONE"),
                self.stop_trace,
            ],
            (ProfilerAction.RECORD, ProfilerAction.WARMUP): [
                partial(warn, "Incorrect schedule: RECORD followed by WARMUP"),
                self.stop_trace,
            ],
            (ProfilerAction.RECORD, ProfilerAction.RECORD): [],
            (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE): [],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [
                self.stop_trace,
                self._trace_ready,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD_AND_SAVE): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
                self.start_trace,
            ],
            # used for exit action
            (ProfilerAction.WARMUP, None): [self.start_trace, self.stop_trace],
            (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready],
            (ProfilerAction.RECORD_AND_SAVE, None): [
                self.stop_trace,
                self._trace_ready,
            ],
        }
        # Start tracking increments to profiler step, this will be used
        # by Kineto
        prof.KinetoStepTracker.init_step_count(PROFILER_STEP_NAME)
    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME)
        if self.execution_trace_observer:
            self.execution_trace_observer.cleanup()

    def start(self):
        self._transit_action(ProfilerAction.NONE, self.current_action)
        if self.record_steps:
            self.step_rec_fn = prof.record_function(
                "ProfilerStep#" + str(self.step_num)
            )
            self.step_rec_fn.__enter__()

    def stop(self):
        if self.record_steps and self.step_rec_fn:
            self.step_rec_fn.__exit__(None, None, None)
        self._transit_action(self.current_action, None)

    def step(self):
        """
        Signals the profiler that the next profiling step has started.
        """
        if self.record_steps and self.step_rec_fn:
            self.step_rec_fn.__exit__(None, None, None)
        prev_action = self.current_action
        self.step_num += 1
        self.current_action = self.schedule(self.step_num)

        self._transit_action(prev_action, self.current_action)
        prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME)

        if self.record_steps:
            self.step_rec_fn = prof.record_function(
                "ProfilerStep#" + str(self.step_num)
            )
            self.step_rec_fn.__enter__()

    def _trace_ready(self):
        if self.on_trace_ready:
            self.on_trace_ready(self)

    def _transit_action(self, prev_action, current_action):
        action_list = self.action_map.get((prev_action, current_action))
        if action_list:
            for action in action_list:
                action()

    def _stats(self) -> Optional[prof._ProfilerStats]:
        if self.profiler is None:
            return None
        return self.profiler._stats


class ExecutionTraceObserver(_ITraceObserver):
    """Execution Trace Observer

    Each process can have a single ExecutionTraceObserver instance. The observer
    can be added to record function callbacks by calling register_callback()
    explicitly. Without calling unregister_callback(), repeated calls to
    register_callback() will not add additional observers to record function
    callbacks. Once an ExecutionTraceObserver is created, the start() and stop()
    methods control when the event data is recorded.

    Deleting or calling unregister_callback() will remove the observer from the
    record function callbacks, finalize the output file, and stop
    incurring any overheads.
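
    Example (a minimal standalone sketch; the output path is arbitrary and
    ``run_workload`` is a hypothetical user function):

    .. code-block:: python

        observer = ExecutionTraceObserver().register_callback("./execution_trace.json")
        observer.start()
        run_workload()
        observer.stop()
        observer.unregister_callback()  # finalizes the output file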
    """

    def __init__(self):
        """
        Initializes the default states.
        """
        self._registered = False
        self._execution_trace_running = False

    def __del__(self):
        """
        Calls unregister_callback() to make sure to finalize outputs.
        """
        self.unregister_callback()

    def register_callback(self, output_file_path: str) -> Self:
        """
        Adds ET observer to record function callbacks. The data will be
        written to output_file_path.
        """
        if not self._registered:
            self._output_file_path = output_file_path
            self._registered = _add_execution_trace_observer(output_file_path)
        return self

    def unregister_callback(self):
        """
        Removes ET observer from record function callbacks.
        """

        def _save_triton_kernels():
            # Save the kernel paths for the generated kernels
            from torch._inductor.codecache import PyCodeCache as PyCodeCache

            kernel_files = [
                v.__file__
                for v in PyCodeCache.cache.values()
                if getattr(v, "__file__", None) is not None
            ]
            work_dir, file_name = os.path.split(self._output_file_path)
            resource_dir = os.path.join(
                work_dir, os.path.splitext(file_name)[0] + "_resources"
            )
            if not os.path.exists(resource_dir):
                os.mkdir(resource_dir)

            for kernel_file in kernel_files:
                if kernel_file is None:
                    continue
                path, name = os.path.split(kernel_file)
                dst = os.path.join(resource_dir, name)
                shutil.copyfile(kernel_file, dst)

        if self._registered:
            self.stop()
            try:
                _save_triton_kernels()
            except Exception as e:
                warn(f"Execution trace failed to save kernels: {e}")
            _remove_execution_trace_observer()
            self._registered = False

    @property
    def is_registered(self):
        """
        Returns True if the execution trace observer is registered, otherwise False.
        """
        return self._registered

    def is_running(self):
        """
        Returns True if the observer is running, otherwise False.
        """
        return self._execution_trace_running

    def start(self):
        """
        Starts capturing.
        """
        if self._registered and not self._execution_trace_running:
            _enable_execution_trace_observer()
            self._execution_trace_running = True
            self._record_pg_config()

    def stop(self):
        """
        Stops capturing.
        """
        if self._execution_trace_running:
            _disable_execution_trace_observer()
            self._execution_trace_running = False

    def cleanup(self):
        """
        Calls unregister_callback() to make sure to finalize outputs.
        """
        self.unregister_callback()

    def get_output_file_path(self) -> str:
        """
        Returns the output file name.
        """
        if self.is_registered:
            return self._output_file_path
        else:
            raise RuntimeError(
                "A callback to the ET profiler needs to be registered "
                "first before getting the output file path"
            )

    def _record_pg_config(self) -> None:
        # Records the PG config info to the trace as node:
        #   ## process_group:init ##
        if (
            self.is_registered
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        ):
            pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info
            torch.autograd._record_function_with_args_enter(
                "## process_group:init ##", json.dumps(pg_config_info)
            )