# _registrations.py
  1. # flake8: noqa: B950
  2. from ._internal import register_artifact, register_log
  3. DYNAMIC = [
  4. "torch.fx.experimental.symbolic_shapes",
  5. "torch.fx.experimental.sym_node",
  6. "torch.fx.experimental.recording",
  7. ]
  8. DISTRIBUTED = [
  9. "torch.distributed",
  10. "torch._dynamo.backends.distributed",
  11. "torch.nn.parallel.distributed",
  12. ]
  13. register_log("dynamo", ["torch._dynamo", *DYNAMIC])
  14. register_log("aot", ["torch._functorch.aot_autograd", "torch._functorch._aot_autograd"])
  15. register_log("autograd", "torch.autograd")
  16. register_log("inductor", ["torch._inductor", "torch._inductor.cudagraph_trees"])
  17. register_artifact(
  18. "cudagraphs",
  19. "Logs information from wrapping inductor generated code with cudagraphs.",
  20. )
  21. register_log("dynamic", DYNAMIC)
  22. register_log("torch", "torch")
  23. register_log("distributed", DISTRIBUTED)
  24. register_log(
  25. "c10d", ["torch.distributed.distributed_c10d", "torch.distributed.rendezvous"]
  26. )
  27. register_log(
  28. "ddp", ["torch.nn.parallel.distributed", "torch._dynamo.backends.distributed"]
  29. )
  30. register_log("pp", ["torch.distributed.pipelining"])
  31. register_log("fsdp", ["torch.distributed.fsdp"])
  32. register_log("onnx", "torch.onnx")
  33. register_log("export", ["torch._dynamo", "torch.export", *DYNAMIC])
  34. register_artifact(
  35. "guards",
  36. "This prints the guards for every compiled Dynamo frame. It does not tell you where the guards come from.",
  37. visible=True,
  38. )
  39. register_artifact("verbose_guards", "", off_by_default=True)
  40. register_artifact(
  41. "bytecode",
  42. "Prints the original and modified bytecode from Dynamo. Mostly useful if you're debugging our bytecode generation in Dynamo.",
  43. off_by_default=True,
  44. )
  45. register_artifact(
  46. "graph",
  47. "Prints the dynamo traced graph (prior to AOTDispatch) in a table. If you prefer python code use `graph_code` instead. ",
  48. )
  49. register_artifact("graph_code", "Like `graph`, but gives you the Python code instead.")
  50. register_artifact(
  51. "graph_sizes", "Prints the sizes of all FX nodes in the dynamo graph."
  52. )
  53. register_artifact(
  54. "trace_source",
  55. "As we execute bytecode, prints the file name / line number we are processing and the actual source code. Useful with `bytecode`",
  56. )
  57. register_artifact(
  58. "trace_call",
  59. "Like trace_source, but it will give you the per-expression blow-by-blow if your Python is recent enough.",
  60. )
  61. register_artifact(
  62. "trace_bytecode",
  63. "As we trace bytecode, prints the instruction and the current stack.",
  64. )
  65. register_artifact(
  66. "aot_graphs",
  67. "Prints the FX forward and backward graph generated by AOTDispatch, after partitioning. Useful to understand what's being given to Inductor",
  68. visible=True,
  69. )
  70. register_artifact(
  71. "aot_joint_graph",
  72. "Print FX joint graph from AOTAutograd, prior to partitioning. Useful for debugging partitioning",
  73. )
  74. register_artifact(
  75. "post_grad_graphs",
  76. "Prints the FX graph generated by post grad passes. Useful to understand what's being given to Inductor after post grad passes",
  77. )
  78. register_artifact(
  79. "compiled_autograd",
  80. "Prints various logs in compiled_autograd, including but not limited to the graphs. Useful for debugging compiled_autograd.",
  81. visible=True,
  82. )
  83. register_artifact(
  84. "compiled_autograd_verbose",
  85. "Will affect performance. Prints compiled_autograd logs with C++ info e.g. autograd node -> fx node mapping",
  86. off_by_default=True,
  87. )
  88. register_artifact(
  89. "ddp_graphs",
  90. "Only relevant for compiling DDP. DDP splits into multiple graphs to trigger comms early. This will print each individual graph here.",
  91. )
  92. register_artifact(
  93. "recompiles",
  94. "Prints the reason why we recompiled a graph. Very, very useful.",
  95. visible=True,
  96. )
  97. register_artifact(
  98. "recompiles_verbose",
  99. "Prints all guard checks that fail during a recompilation. "
  100. "At runtime, Dynamo will stop at the first failed check for each failing guard. "
  101. "So not all logged failing checks are actually ran by Dynamo.",
  102. visible=True,
  103. off_by_default=True,
  104. )
  105. register_artifact(
  106. "graph_breaks",
  107. "Prints whenever Dynamo decides that it needs to graph break (i.e. create a new graph). Useful for debugging why torch.compile has poor performance",
  108. visible=True,
  109. )
  110. register_artifact(
  111. "not_implemented",
  112. "Prints log messages whenever we return NotImplemented in a multi-dispatch, letting you trace through each object we attempted to dispatch to",
  113. )
  114. register_artifact(
  115. "output_code",
  116. "Prints the code that Inductor generates (either Triton or C++)",
  117. off_by_default=True,
  118. visible=True,
  119. )
  120. register_artifact(
  121. "kernel_code",
  122. "Prints the code that Inductor generates (on a per-kernel basis)",
  123. off_by_default=True,
  124. visible=True,
  125. )
  126. register_artifact(
  127. "schedule",
  128. "Inductor scheduler information. Useful if working on Inductor fusion algo",
  129. off_by_default=True,
  130. )
  131. register_artifact("perf_hints", "", off_by_default=True)
  132. register_artifact("onnx_diagnostics", "", off_by_default=True)
  133. register_artifact(
  134. "fusion",
  135. "Detailed Inductor fusion decisions. More detailed than 'schedule'",
  136. off_by_default=True,
  137. )
  138. register_artifact(
  139. "overlap",
  140. "Detailed Inductor compute/comm overlap decisions",
  141. off_by_default=True,
  142. )
  143. register_artifact(
  144. "sym_node",
  145. "Logs extra info for various SymNode operations",
  146. off_by_default=True,
  147. )
  148. register_artifact("custom_format_test_artifact", "Testing only", log_format="")