_internal.py

# mypy: allow-untyped-defs
import functools
import hashlib
import itertools
import json
import logging
import os
import os.path
import re
import tempfile
from dataclasses import dataclass, field
from importlib import __import__
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from weakref import WeakSet

import torch._logging.structured
from torch.utils._traceback import CapturedTraceback

log = logging.getLogger(__name__)

# This is a synthetic logger which doesn't correspond to an actual logger,
# but handles all of our "tracing" logging, which is structured and doesn't go
# to stderr but always goes to a dedicated log file.  We don't put these
# loggers in the classic module hierarchy, because we don't want a suppression
# of logs to also cause a trace to get suppressed (traces typically are not
# collected, unless we are in prod, in which case they always are collected.)
#
# TODO: Maybe we should allow for some sub-hierarchy so you can control which
# traces you want to collect, for performance reasons.
#
# See https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit
trace_log = logging.getLogger("torch.__trace")

DEFAULT_LOG_LEVEL = logging.WARNING
LOG_ENV_VAR = "TORCH_LOGS"
LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT"
LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT"
TRACE_ENV_VAR = "TORCH_TRACE"


@dataclass
class LogRegistry:
    # shorthand name to log qualified name
    # Note: this only contains loggers registered
    # from register_log
    # e.g. "dynamo" -> "torch._dynamo"
    log_alias_to_log_qnames: Dict[str, List[str]] = field(default_factory=dict)

    # artifact logger qualified names,
    # this is populated lazily, as calls to getArtifactLogger
    # currently formatted as <module>.__<artifact_name>
    # e.g. "torch._dynamo.convert_frame.__guards"
    artifact_log_qnames: Set[str] = field(default_factory=set)

    # child logs of registered logs if specified via open
    # registration by the user (i.e. placing "torch._dynamo.output_graph" in the env var)
    # these need to be tracked so their levels can be reset properly
    # e.g. "torch._dynamo.output_graph"
    child_log_qnames: Set[str] = field(default_factory=set)

    # artifact names, populated by register_artifact
    # e.g. "guards"
    artifact_names: Set[str] = field(default_factory=set)

    # Artifacts that should be visible by default in the error message
    visible_artifacts: Set[str] = field(default_factory=set)

    # A short description of each artifact
    artifact_descriptions: Dict[str, str] = field(default_factory=dict)

    # artifacts which are not displayed unless explicitly named in the
    # settings. Ex. output_code is NOT displayed even if the inductor
    # log level is set to DEBUG. It must be explicitly named in the settings
    off_by_default_artifact_names: Set[str] = field(default_factory=set)

    # logging format string for artifacts
    artifact_log_formatters: Dict[str, logging.Formatter] = field(default_factory=dict)

    def is_artifact(self, name):
        return name in self.artifact_names

    def is_log(self, alias):
        return alias in self.log_alias_to_log_qnames

    # register a log with an alias
    def register_log(self, alias, log_qnames: Union[str, List[str]]):
        if isinstance(log_qnames, str):
            log_qnames = [log_qnames]
        self.log_alias_to_log_qnames[alias] = log_qnames

    # register an artifact name
    def register_artifact_name(
        self, name, description, visible, off_by_default, log_format
    ):
        self.artifact_names.add(name)
        if visible:
            self.visible_artifacts.add(name)
        self.artifact_descriptions[name] = description

        # if off by default, don't enable it
        # when log_name's log_level is set to DEBUG
        if off_by_default:
            self.off_by_default_artifact_names.add(name)

        if log_format is not None:
            self.artifact_log_formatters[name] = logging.Formatter(log_format)

    # register the qualified name of an artifact log
    # this is needed to know which logs need to be reset
    # whenever the log_state is changed
    def register_artifact_log(self, artifact_log_qname):
        self.artifact_log_qnames.add(artifact_log_qname)

    def register_child_log(self, log_qname):
        self.child_log_qnames.add(log_qname)

    # flattens all the qnames together (TODO: consider memoizing?)
    def get_log_qnames(self) -> Set[str]:
        return {
            qname
            for qnames in self.log_alias_to_log_qnames.values()
            for qname in qnames
        }

    def get_artifact_log_qnames(self):
        return set(self.artifact_log_qnames)

    def get_child_log_qnames(self):
        return set(self.child_log_qnames)

    def is_off_by_default(self, artifact_qname):
        return artifact_qname in self.off_by_default_artifact_names


@dataclass
class LogState:
    # qualified log names -> currently set log level
    log_qname_to_level: Dict[str, str] = field(default_factory=dict)

    # the set of currently enabled artifacts
    artifact_names: Set[str] = field(default_factory=set)

    def enable_artifact(self, artifact_name):
        self.artifact_names.add(artifact_name)

    def is_artifact_enabled(self, name):
        return name in self.artifact_names

    def enable_log(self, log_qnames, log_level):
        if isinstance(log_qnames, str):
            log_qnames = [log_qnames]
        for log_qname in log_qnames:
            self.log_qname_to_level[log_qname] = log_level

    def get_log_level_pairs(self):
        """Returns all qualified module names for which the user requested
        explicit logging settings.

        .. warning::

            This function used to return all loggers, regardless of whether
            the user specified them or not; it now only returns logs which
            were explicitly mentioned by the user (and torch, which is always
            implicitly requested when we initialize our logging subsystem).
        """
        return self.log_qname_to_level.items()

    def clear(self):
        self.log_qname_to_level.clear()
        self.artifact_names.clear()
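
# A minimal sketch (for illustration only) of how LogRegistry and LogState
# cooperate; the "dynamo" -> "torch._dynamo" mapping is the one documented above:
#
#     registry = LogRegistry()
#     registry.register_log("dynamo", "torch._dynamo")
#     state = LogState()
#     state.enable_log(registry.log_alias_to_log_qnames["dynamo"], logging.DEBUG)
#     assert dict(state.get_log_level_pairs()) == {"torch._dynamo": logging.DEBUG}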


log_registry = LogRegistry()
log_state = LogState()

# sample usage: torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING)
DEFAULT_LOGGING = {
    "dynamo": logging.DEBUG,
    "aot": logging.DEBUG,
    "inductor": logging.DEBUG,
    "ddp_graphs": True,
    "graph_breaks": True,
    "guards": True,
    "recompiles": True,
    "dynamic": logging.INFO,
}
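
# A roughly equivalent TORCH_LOGS setting (a sketch, based on the parsing rules
# further below: "+" means DEBUG, a bare log name means INFO, artifact names act
# as flags):
#
#     TORCH_LOGS="+dynamo,+aot,+inductor,dynamic,ddp_graphs,graph_breaks,guards,recompiles"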


def set_logs(
    *,
    all: Optional[int] = None,
    dynamo: Optional[int] = None,
    aot: Optional[int] = None,
    autograd: Optional[int] = None,
    dynamic: Optional[int] = None,
    inductor: Optional[int] = None,
    distributed: Optional[int] = None,
    dist_c10d: Optional[int] = None,
    dist_ddp: Optional[int] = None,
    dist_fsdp: Optional[int] = None,
    onnx: Optional[int] = None,
    bytecode: bool = False,
    aot_graphs: bool = False,
    aot_joint_graph: bool = False,
    ddp_graphs: bool = False,
    graph: bool = False,
    graph_code: bool = False,
    graph_breaks: bool = False,
    graph_sizes: bool = False,
    guards: bool = False,
    recompiles: bool = False,
    recompiles_verbose: bool = False,
    trace_source: bool = False,
    trace_call: bool = False,
    trace_bytecode: bool = False,
    output_code: bool = False,
    kernel_code: bool = False,
    schedule: bool = False,
    perf_hints: bool = False,
    post_grad_graphs: bool = False,
    onnx_diagnostics: bool = False,
    fusion: bool = False,
    overlap: bool = False,
    export: Optional[int] = None,
    modules: Optional[Dict[str, Union[int, bool]]] = None,
    cudagraphs: bool = False,
    sym_node: bool = False,
    compiled_autograd_verbose: bool = False,
):
  191. """
  192. Sets the log level for individual components and toggles individual log
  193. artifact types.
  194. .. warning:: This feature is a prototype and may have compatibility
  195. breaking changes in the future.
  196. .. note:: The ``TORCH_LOGS`` environment variable has complete precedence
  197. over this function, so if it was set, this function does nothing.
  198. A component is a set of related features in PyTorch. All of the log
  199. messages emitted from a given component have their own log levels. If the
  200. log level of a particular message has priority greater than or equal to its
  201. component's log level setting, it is emitted. Otherwise, it is suppressed.
  202. This allows you to, for instance, silence large groups of log messages that
  203. are not relevant to you and increase verbosity of logs for components that
  204. are relevant. The expected log level values, ordered from highest to lowest
  205. priority, are:
  206. * ``logging.CRITICAL``
  207. * ``logging.ERROR``
  208. * ``logging.WARNING``
  209. * ``logging.INFO``
  210. * ``logging.DEBUG``
  211. * ``logging.NOTSET``
  212. See documentation for the Python ``logging`` module for more information on
  213. log levels: `<https://docs.python.org/3/library/logging.html#logging-levels>`_
  214. An artifact is a particular type of log message. Each artifact is assigned
  215. to a parent component. A component can emit many different kinds of
  216. artifacts. In general, an artifact is emitted if either its corresponding
  217. setting in the argument list below is turned on or if its parent component
  218. is set to a log level less than or equal to the log level of the artifact.
  219. Keyword args:
  220. all (:class:`Optional[int]`):
  221. The default log level for all components. Default: ``logging.WARN``
  222. dynamo (:class:`Optional[int]`):
  223. The log level for the TorchDynamo component. Default: ``logging.WARN``
  224. aot (:class:`Optional[int]`):
  225. The log level for the AOTAutograd component. Default: ``logging.WARN``
  226. autograd (:class:`Optional[int]`):
  227. The log level for autograd. Default: ``logging.WARN``
  228. inductor (:class:`Optional[int]`):
  229. The log level for the TorchInductor component. Default: ``logging.WARN``
  230. dynamic (:class:`Optional[int]`):
  231. The log level for dynamic shapes. Default: ``logging.WARN``
  232. distributed (:class:`Optional[int]`):
  233. Whether to log c10d communication operations and other debug info from PyTorch Distributed components.
  234. Default: ``logging.WARN``
  235. dist_c10d (:class:`Optional[int]`):
  236. Whether to log c10d communication operations related debug info in PyTorch Distributed components.
  237. Default: ``logging.WARN``
  238. dist_ddp (:class:`Optional[int]`):
  239. Whether to log debug info related to ``DistributedDataParallel``(DDP) from PyTorch Distributed components.
  240. Default: ``logging.WARN``
  241. dist_fsdp (:class:`Optional[int]`):
  242. Whether to log debug info related to ``FullyShardedDataParallel``(FSDP) in PyTorch Distributed components.
  243. Default: ``logging.WARN``
  244. onnx (:class:`Optional[int]`):
  245. The log level for the ONNX exporter component. Default: ``logging.WARN``
  246. bytecode (:class:`bool`):
  247. Whether to emit the original and generated bytecode from TorchDynamo.
  248. Default: ``False``
  249. aot_graphs (:class:`bool`):
  250. Whether to emit the graphs generated by AOTAutograd. Default: ``False``
  251. aot_joint_graph (:class:`bool`):
  252. Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False``
  253. inductor (:class:`Optional[int]`):
  254. Whether to log information from inductor cudagraphs. Default: ``logging.WARN``
  255. ddp_graphs (:class:`bool`):
  256. Whether to emit graphs generated by DDPOptimizer. Default: ``False``
  257. graph (:class:`bool`):
  258. Whether to emit the graph captured by TorchDynamo in tabular format.
  259. Default: ``False``
  260. graph_code (:class:`bool`):
  261. Whether to emit the python source of the graph captured by TorchDynamo.
  262. Default: ``False``
  263. graph_breaks (:class:`bool`):
  264. Whether to emit the graph breaks encountered by TorchDynamo.
  265. Default: ``False``
  266. graph_sizes (:class:`bool`):
  267. Whether to emit tensor sizes of the graph captured by TorchDynamo.
  268. Default: ``False``
  269. guards (:class:`bool`):
  270. Whether to emit the guards generated by TorchDynamo for each compiled
  271. function. Default: ``False``
  272. recompiles (:class:`bool`):
  273. Whether to emit a guard failure reason and message every time
  274. TorchDynamo recompiles a function. Default: ``False``
  275. recompiles_verbose (:class:`bool`):
  276. Whether to emit all guard failure reasons when TorchDynamo recompiles
  277. a function, even those that are not actually run. Default: ``False``
  278. trace_source (:class:`bool`):
  279. Whether to emit when TorchDynamo begins tracing a new line. Default: ``False``
  280. trace_call (:class:`bool`):
  281. Whether to emit detailed line location when TorchDynamo creates an FX node
  282. corresponding to function call. Python 3.11+ only. Default: ``False``
  283. trace_bytecode (:class:`bool`):
  284. Whether to emit bytecode instructions and traced stack state as TorchDynamo
  285. traces bytecode. Default: ``False``
  286. output_code (:class:`bool`):
  287. Whether to emit the TorchInductor output code on a per-graph basis. Default: ``False``
  288. kernel_code (:class:`bool`):
  289. Whether to emit the TorchInductor output code on a per-kernel bases. Default: ``False``
  290. schedule (:class:`bool`):
  291. Whether to emit the TorchInductor schedule. Default: ``False``
  292. perf_hints (:class:`bool`):
  293. Whether to emit the TorchInductor perf hints. Default: ``False``
  294. post_grad_graphs (:class:`bool`):
  295. Whether to emit the graphs generated by after post grad passes. Default: ``False``
  296. onnx_diagnostics (:class:`bool`):
  297. Whether to emit the ONNX exporter diagnostics in logging. Default: ``False``
  298. fusion (:class:`bool`):
  299. Whether to emit detailed Inductor fusion decisions. Default: ``False``
  300. overlap (:class:`bool`):
  301. Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False``
  302. sym_node (:class:`bool`):
  303. Whether to emit debug info for various SymNode opterations. Default: ``False``
  304. export (:class:`Optional[int]`):
  305. The log level for export. Default: ``logging.WARN``
  306. modules (dict):
  307. This argument provides an alternate way to specify the above log
  308. component and artifact settings, in the format of a keyword args
  309. dictionary given as a single argument. There are two cases
  310. where this is useful (1) if a new log component or artifact has
  311. been registered but a keyword argument for it has not been added
  312. to this function and (2) if the log level for an unregistered module
  313. needs to be set. This can be done by providing the fully-qualified module
  314. name as the key, with the log level as the value. Default: ``None``
  315. Example::
  316. >>> # xdoctest: +SKIP
  317. >>> import logging
  318. # The following changes the "dynamo" component to emit DEBUG-level
  319. # logs, and to emit "graph_code" artifacts.
  320. >>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True)
  321. # The following enables the logs for a different module
  322. >>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG})
  323. """
    # ignore if env var is set
    if LOG_ENV_VAR in os.environ:
        log.warning(
            "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs"
        )
        return

    log_state.clear()

    modules = modules or {}

    def _set_logs(**kwargs):
        for alias, val in itertools.chain(kwargs.items(), modules.items()):  # type: ignore[union-attr]
            if val is None:
                continue

            if log_registry.is_artifact(alias):
                if not isinstance(val, bool):
                    raise ValueError(
                        f"Expected bool to enable artifact {alias}, received {val}"
                    )

                if val:
                    log_state.enable_artifact(alias)
            elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames:
                if val not in logging._levelToName:
                    raise ValueError(
                        f"Unrecognized log level for log {alias}: {val}, valid level values "
                        f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
                    )

                log_state.enable_log(
                    log_registry.log_alias_to_log_qnames.get(alias, alias), val
                )
            else:
                raise ValueError(
                    f"Unrecognized log or artifact name passed to set_logs: {alias}"
                )

        _init_logs()

    _set_logs(
        torch=all,
        dynamo=dynamo,
        aot=aot,
        autograd=autograd,
        inductor=inductor,
        dynamic=dynamic,
        bytecode=bytecode,
        aot_graphs=aot_graphs,
        aot_joint_graph=aot_joint_graph,
        ddp_graphs=ddp_graphs,
        distributed=distributed,
        dist_c10d=dist_c10d,
        dist_ddp=dist_ddp,
        dist_fsdp=dist_fsdp,
        graph=graph,
        graph_code=graph_code,
        graph_breaks=graph_breaks,
        graph_sizes=graph_sizes,
        guards=guards,
        recompiles=recompiles,
        recompiles_verbose=recompiles_verbose,
        trace_source=trace_source,
        trace_call=trace_call,
        trace_bytecode=trace_bytecode,
        output_code=output_code,
        kernel_code=kernel_code,
        schedule=schedule,
        perf_hints=perf_hints,
        post_grad_graphs=post_grad_graphs,
        onnx=onnx,
        onnx_diagnostics=onnx_diagnostics,
        fusion=fusion,
        overlap=overlap,
        sym_node=sym_node,
        export=export,
        cudagraphs=cudagraphs,
        compiled_autograd_verbose=compiled_autograd_verbose,
    )


def get_loggers():
    """
    Returns: a list of all registered loggers
    """
    return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()]


def register_log(setting_name, log_name):
    """
    Enables a log to be controlled by the env var and user API with the setting_name
    Args:
        setting_name: the shorthand name used in the env var and user API
        log_name: the log name that the setting_name is associated with
    """
    log_registry.register_log(setting_name, log_name)
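
# Illustrative registrations (a sketch; the real aliases and qualified names are
# defined in torch._logging's registration module, not here):
#
#     register_log("dynamo", "torch._dynamo")                     # single qualified name
#     register_log("my_alias", ["pkg.module_a", "pkg.module_b"])  # or several (hypothetical names)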


def register_artifact(
    setting_name, description, visible=False, off_by_default=False, log_format=None
):
    """
    Enables an artifact to be controlled by the env var and user API with the setting_name
    Args:
        setting_name: the shorthand name used in the env var and user API
        description: A description of what this outputs
        visible: Whether it gets suggested to users by default
        off_by_default: if True, this artifact is not emitted even when its ancestor
            loggers are enabled at level DEBUG; it must be named explicitly in the settings
        log_format: an optional format string used to format records emitted by this artifact
    """
    log_registry.register_artifact_name(
        setting_name, description, visible, off_by_default, log_format
    )
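
# Illustrative only -- the real artifact registrations (e.g. "guards",
# "output_code") live alongside the log registrations; the description text
# here is made up:
#
#     register_artifact("guards", "Guards generated for each compiled function", visible=True)
#     register_artifact("output_code", "TorchInductor output code", off_by_default=True)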


def getArtifactLogger(module_qname, artifact_name):
    if artifact_name not in log_registry.artifact_names:
        raise ValueError(
            f"Artifact name: {repr(artifact_name)} not registered, "
            f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations."
        )
    qname = module_qname + f".__{artifact_name}"
    log = logging.getLogger(qname)
    log.artifact_name = artifact_name  # type: ignore[attr-defined]
    log_registry.register_artifact_log(qname)
    configure_artifact_log(log)
    return log
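
# Typical use is a module-level artifact logger (a sketch; the artifact name
# must already be registered, and the message here is made up):
#
#     graph_breaks_log = getArtifactLogger(__name__, "graph_breaks")
#     graph_breaks_log.debug("Graph break: %s", reason)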


INCR_VERBOSITY_CHAR = "+"
DECR_VERBOSITY_CHAR = "-"
VERBOSITY_REGEX = (
    "("
    + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)])
    + "?)"
)


def configure_artifact_log(log):
    # If the artifact is off by default, then it should only be logged when explicitly
    # enabled; set propagate to False so that this artifact is not propagated
    # to its ancestor logger
    if log_registry.is_off_by_default(log.artifact_name):
        log.propagate = False

    # enable artifact logging when explicitly enabled
    if log_state.is_artifact_enabled(log.artifact_name):
        log.setLevel(logging.DEBUG)
        log.propagate = True


# match a comma separated list of loggable names (whitespace allowed after commas)
def _gen_settings_regex():
    return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?")


def _validate_settings(settings):
    return re.fullmatch(_gen_settings_regex(), settings) is not None
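
# A sketch of the regex's behavior:
#
#     _validate_settings("+dynamo,guards,some.module")   # True
#     _validate_settings("dynamo;guards")                # False (";" is not allowed)
#     _validate_settings("")                             # False (the empty string is
#                                                        # special-cased in _parse_log_settings below)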


def help_message(verbose=False):
    def pad_to(s, length=30):
        assert len(s) <= length
        return s + " " * (length - len(s))

    if verbose:
        printed_artifacts = log_registry.artifact_names
    else:
        printed_artifacts = log_registry.visible_artifacts

    if verbose:
        heading = "All registered names"
    else:
        heading = "Visible registered names (use TORCH_LOGS='+help' for full list)"
    lines = (
        ["all"]
        + sorted(log_registry.log_alias_to_log_qnames.keys())
        + sorted(
            [
                f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}"
                for name in printed_artifacts
            ]
        )
    )
    setting_info = " " + "\n ".join(lines)
    examples = """
Examples:
  TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to
  logging.DEBUG and AOT to logging.INFO

  TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to
  logging.ERROR and TorchInductor to logging.DEBUG

  TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact

  TORCH_LOGS="+dynamo,schedule" will set the log level of TorchDynamo
  to logging.DEBUG and enable the schedule artifact

  TORCH_LOGS="+some.random.module,schedule" will set the log level of
  some.random.module to logging.DEBUG and enable the schedule artifact

  TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format
  string will set the output format
  Valid keys are "levelname", "message", "pathname", "levelno", "lineno",
  "filename" and "name".

  TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as
  well. This is useful when the output is long.
"""  # flake8: noqa: B950
    msg = f"""
TORCH_LOGS Info
{examples}

{heading}
{setting_info}
"""
    return msg


def _invalid_settings_err_msg(settings, verbose=False):
    valid_settings = ", ".join(
        ["all"]
        + list(log_registry.log_alias_to_log_qnames.keys())
        + list(log_registry.artifact_names)
    )
    msg = f"""
Invalid log settings: {settings}, must be a comma separated list of fully
qualified module names, registered log names or registered artifact names.
For more info on various settings, try TORCH_LOGS="help"
Valid settings:
{valid_settings}
"""
    return msg


@functools.lru_cache
def _parse_log_settings(settings):
    if settings == "":
        return dict()

    if settings == "help":
        raise ValueError(help_message(verbose=False))
    elif settings == "+help":
        raise ValueError(help_message(verbose=True))
    if not _validate_settings(settings):
        raise ValueError(_invalid_settings_err_msg(settings))

    settings = re.sub(r"\s+", "", settings)
    log_names = settings.split(",")

    def get_name_level_pair(name):
        clean_name = name.replace(INCR_VERBOSITY_CHAR, "")
        clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "")

        if name[0] == INCR_VERBOSITY_CHAR:
            level = logging.DEBUG
        elif name[0] == DECR_VERBOSITY_CHAR:
            level = logging.ERROR
        else:
            level = logging.INFO

        return clean_name, level

    log_state = LogState()

    for name in log_names:
        name, level = get_name_level_pair(name)

        if name == "all":
            name = "torch"

        if log_registry.is_log(name):
            assert level is not None
            log_qnames = log_registry.log_alias_to_log_qnames[name]
            log_state.enable_log(log_qnames, level)
        elif log_registry.is_artifact(name):
            log_state.enable_artifact(name)
        elif _is_valid_module(name):
            if not _has_registered_parent(name):
                log_registry.register_log(name, name)
            else:
                log_registry.register_child_log(name)
            log_state.enable_log(name, level)
        else:
            raise ValueError(_invalid_settings_err_msg(settings))

    return log_state
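
# Worked example (a sketch, assuming "some.module" is importable and "dynamo"
# and "guards" are registered as a log alias and an artifact respectively):
#
#     state = _parse_log_settings("+dynamo,guards,some.module")
#     # state maps each of "dynamo"'s qualified names to logging.DEBUG,
#     # maps "some.module" to logging.INFO (bare names get INFO), and
#     # state.is_artifact_enabled("guards") is True.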


def _is_valid_module(qname):
    try:
        __import__(qname)
        return True
    except ImportError:
        return False


def _update_log_state_from_env():
    global log_state
    log_setting = os.environ.get(LOG_ENV_VAR, None)
    if log_setting is not None:
        log_state = _parse_log_settings(log_setting)


def _has_registered_parent(log_qname):
    cur_log = logging.getLogger(log_qname)

    registered_log_qnames = log_registry.get_log_qnames()

    while cur_log.parent:
        if cur_log.name in registered_log_qnames:
            return True
        cur_log = cur_log.parent

    return False
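
# For instance, once "torch._dynamo" is a registered qualified name,
# _has_registered_parent("torch._dynamo.output_graph") is True (the logging
# module's dot-separated hierarchy supplies the parent chain), so that name is
# tracked as a child log instead of being registered as a new top-level log.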


# apply custom formats to artifacts when necessary
class TorchLogsFormatter(logging.Formatter):
    def __init__(self, *, trace: bool = False):
        super().__init__()
        self._is_trace = trace

    def format(self, record):
        artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None)
        if artifact_name is not None:
            artifact_formatter = log_registry.artifact_log_formatters.get(
                artifact_name, None
            )
            if artifact_formatter is not None:
                return artifact_formatter.format(record)

        record.message = record.getMessage()
        record.asctime = self.formatTime(record, "%m%d %H:%M:%S")

        # exception handling - copied from logging.Formatter.format
        s = record.message
        if record.exc_info:
            # Cache the traceback text to avoid converting it multiple times
            # (it's constant anyway)
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            if s[-1:] != "\n":
                s = s + "\n"
            s = s + record.exc_text
        if record.stack_info:
            if s[-1:] != "\n":
                s = s + "\n"
            s = s + self.formatStack(record.stack_info)

        record.rankprefix = ""
        if not self._is_trace and dist.is_available() and dist.is_initialized():
            record.rankprefix = f"[rank{dist.get_rank()}]:"

        record.traceid = ""
        if (
            not self._is_trace
            and (trace_id := torch._guards.CompileContext.current_trace_id())
            is not None
        ):
            record.traceid = f" [{trace_id}]"

        glog_level_to_abbr = {
            "DEBUG": "V",  # V is for VERBOSE in glog
            "INFO": "I",
            "WARNING": "W",
            "ERROR": "E",
            "CRITICAL": "C",
        }
        shortlevel = glog_level_to_abbr.get(record.levelname, record.levelname)

        record.artifactprefix = ""
        if artifact_name is not None:
            record.artifactprefix = f" [__{artifact_name}]"

        prefix = (
            f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs*1000):06d} {record.thread} "
            f"{os.path.relpath(record.pathname, os.path.dirname(os.path.dirname(torch.__file__)))}:"
            f"{record.lineno}]{record.traceid}{record.artifactprefix}"
        )
        if self._is_trace:
            assert s == ""
            try:
                r = f"{prefix} {json.dumps(record.metadata)}"
            except TypeError:
                log.warning("failing metadata: %r", record.metadata)
                raise
            if record.payload is not None:
                r += "".join(f"\n\t{l}" for l in record.payload.split("\n"))
            return r
        else:
            lines = s.split("\n")
            return "\n".join(f"{prefix} {l}" for l in lines)
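
# An illustrative (made-up) non-trace output line, showing what the prefix above
# produces: glog-style short level + date/time, thread id, path relative to the
# torch install, line number, the current compile id, and the artifact suffix
# when the record came from an artifact logger:
#
#   [rank0]:V0210 12:30:45.000123 140249 torch/_dynamo/convert_frame.py:123] [0/0] [__guards] GUARDS: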


def _default_formatter():
    fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None)
    if fmt is None:
        return TorchLogsFormatter()
    else:
        if fmt in ("short", "basic"):
            fmt = logging.BASIC_FORMAT
        return logging.Formatter(fmt)


DEFAULT_FORMATTER = _default_formatter()
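
# For example, TORCH_LOGS_FORMAT="short" (or "basic") selects logging.BASIC_FORMAT,
# while something like TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" installs a
# plain logging.Formatter with that format string instead of TorchLogsFormatter.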


def _setup_handlers(create_handler_fn, log):
    debug_handler = _track_handler(create_handler_fn())
    debug_handler.setFormatter(DEFAULT_FORMATTER)
    debug_handler.setLevel(logging.DEBUG)
    log.addHandler(debug_handler)


handlers = WeakSet()  # type: ignore[var-annotated]


# mark handlers that we've created
# so we don't modify user handlers
def _track_handler(handler):
    handlers.add(handler)
    return handler


def _is_torch_handler(handler):
    return handler in handlers


# clears all torch handlers on specified loggers
def _clear_handlers(log):
    to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)]
    for handler in to_remove:
        log.removeHandler(handler)


def _reset_logs():
    # reset all registered logs
    for log_qname in log_registry.get_log_qnames():
        log = logging.getLogger(log_qname)
        log.setLevel(logging.WARNING)
        log.propagate = False
        _clear_handlers(log)

    # reset all artifact and child logs
    for artifact_log_qname in itertools.chain(
        log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames()
    ):
        log = logging.getLogger(artifact_log_qname)
        log.setLevel(logging.NOTSET)
        log.propagate = True

    trace_log.propagate = False
    _clear_handlers(trace_log)


def _get_log_state():
    return log_state


def _set_log_state(state):
    global log_state
    log_state = state


def _init_logs(log_file_name=None):
    _reset_logs()
    _update_log_state_from_env()

    out = os.environ.get(LOG_OUT_ENV_VAR, None)
    if out is not None:
        log_file_name = out

    # First, reset all known (registered) loggers to NOTSET, so that they
    # respect their parent log level
    for log_qname in log_registry.get_log_qnames():
        # But not the top level torch level: this defaults to WARNING so
        # that our log messages don't leak to the lower levels
        if log_qname == "torch":
            continue
        log = logging.getLogger(log_qname)
        log.setLevel(logging.NOTSET)

    # Now, for all loggers which the user requested to have non-standard
    # logging behavior, modify their log levels
    for log_qname, level in log_state.get_log_level_pairs():
        log = logging.getLogger(log_qname)
        log.setLevel(level)

    # Finally, setup handlers for all registered loggers
    for log_qname in log_registry.get_log_qnames():
        log = logging.getLogger(log_qname)
        _setup_handlers(
            logging.StreamHandler,
            log,
        )

        if log_file_name is not None:
            _setup_handlers(
                lambda: logging.FileHandler(log_file_name),
                log,
            )

    # configure artifact loggers, note: this must happen last
    # since the levels of ancestor loggers are taken into account
    for artifact_log_qname in log_registry.get_artifact_log_qnames():
        log = logging.getLogger(artifact_log_qname)
        configure_artifact_log(log)

    # Setup handler for the special trace_log, with different default
    # configuration
    trace_dir_name = os.environ.get(TRACE_ENV_VAR, None)
    # This handler may remove itself if trace_dir_name is None and we are not
    # actually in an FB environment.  This allows us to defer actually
    # initializing it until we actually need to log anything.  This is
    # important because JK initializes a C++ singleton, which will pork our
    # process if we subsequently fork.
    handler = LazyTraceHandler(trace_dir_name)
    # This log is ALWAYS at debug level.  We will additionally test if there
    # are any handlers before deciding to actually call logging on this.  Do
    # not manually call
    trace_log.setLevel(logging.DEBUG)
    trace_log_handler = _track_handler(handler)
    trace_log_handler.setFormatter(TorchLogsFormatter(trace=True))
    trace_log.addHandler(trace_log_handler)


class LazyTraceHandler(logging.StreamHandler):
    """Like FileHandler, but the file is allocated lazily only upon the first log message"""

    def __init__(self, root_dir: Optional[str]):
        # This is implemented in the same way that delay is implemented on
        # FileHandler
        self.root_dir = root_dir
        logging.Handler.__init__(self)
        self.stream = None
        self._builtin_open = open

    # cloned from FileHandler in cpython
    def close(self):
        self.acquire()
        try:
            try:
                if self.stream:
                    try:
                        self.flush()
                    finally:
                        stream = self.stream
                        self.stream = None
                        if hasattr(stream, "close"):
                            stream.close()
            finally:
                # Issue #19523: call unconditionally to
                # prevent a handler leak when delay is set
                # Also see Issue #42378: we also rely on
                # self._closed being set to True there
                logging.StreamHandler.close(self)
        finally:
            self.release()

    def emit(self, record):
        if self.stream is None:
            ok = False
            if self.root_dir is None:
                TRACE_LOG_DIR = "/logs"
                open_func = self._builtin_open

                import torch.version as torch_version

                if hasattr(torch_version, "git_version"):
                    log.info("LazyTraceHandler: disabled because not fbcode")
                elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
                    log.info(
                        "LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
                    )
                elif not os.path.exists(TRACE_LOG_DIR):
                    log.info(
                        "LazyTraceHandler: disabled because %s does not exist",
                        TRACE_LOG_DIR,
                    )
                elif not os.access(TRACE_LOG_DIR, os.W_OK):
                    log.info(
                        "LazyTraceHandler: disabled because %s is not writeable",
                        TRACE_LOG_DIR,
                    )
                else:
                    self.root_dir = TRACE_LOG_DIR

            if self.root_dir is not None:
                os.makedirs(self.root_dir, exist_ok=True)
                ranksuffix = ""
                if dist.is_available() and dist.is_initialized():
                    ranksuffix = f"rank_{dist.get_rank()}_"
                self.stream = tempfile.NamedTemporaryFile(
                    mode="w+",
                    suffix=".log",
                    prefix=f"dedicated_log_torch_trace_{ranksuffix}",
                    dir=self.root_dir,
                    delete=False,
                )
                log.info("LazyTraceHandler: logging to %s", self.stream.name)
            else:
                # We go poof, remove and no-op
                trace_log.removeHandler(self)
                return
        if self.stream:
            super().emit(record)


@functools.lru_cache(None)
def warning_once(logger_obj, *args, **kwargs):
    """
    This function is similar to `logger.warning()`, but will emit the warning with the same message only once

    Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
    The assumption here is that all warning messages are unique across the code. If they aren't, then we need to switch
    to another type of cache that includes the caller frame information in the hashing function.
    """
    logger_obj.warning(*args, **kwargs)
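
# Usage sketch (the message is made up):
#
#     warning_once(log, "setting X is deprecated; use Y instead")
#
# Because of the functools.lru_cache above, repeating this exact call (same
# logger, same message) is a no-op after the first emission.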


class LazyString:
    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __str__(self):
        return self.func(*self.args, **self.kwargs)
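
# LazyString defers an expensive formatting call until a handler actually
# renders the record, e.g. (illustrative; some_expensive_repr is hypothetical):
#
#     log.debug("%s", LazyString(lambda: some_expensive_repr(graph)))
#
# If the logger is not enabled for DEBUG, __str__ is never called and the
# expensive work is skipped.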


def trace_structured(
    name: str,
    # NB: metadata expected to be dict so adding more info is forward compatible
    # Tuple[str, int] is a special case for string interning
    metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict,
    *,
    payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
    suppress_context: bool = False,
):
    """
    metadata is an arbitrary JSON compatible struct, but it's expected to not be
    too long (e.g., less than 1MB)

    payload is an arbitrary string, which can be arbitrarily long (but expected to have
    newlines so no lines are too long)
    """
    assert name not in ["rank", "frame_id", "frame_compile_id", "attempt"]
    assert callable(
        metadata_fn
    ), f"metadata_fn should be callable, but got {type(metadata_fn)}"
    assert callable(
        payload_fn
    ), f"payload_fn should be callable, but got {type(payload_fn)}"
    # trace_log never propagates and is ALWAYS DEBUG, so also check that there
    # are handlers instead of checking the log level
    if trace_log.handlers:
        record: Dict[str, object] = {}
        record[name] = metadata_fn()
        if not suppress_context:
            # TODO: Actually, the rank probably should just be emitted once at
            # the top, and not repeatedly spammed in all the logs, since it
            # never changes and we assume no interleaving
            if dist.is_available() and dist.is_initialized():
                record["rank"] = dist.get_rank()
            if (
                trace_id := torch._guards.CompileContext.current_trace_id()
            ) is not None:
                record["frame_id"] = trace_id.compile_id.frame_id
                record["frame_compile_id"] = trace_id.compile_id.frame_compile_id
                record["attempt"] = trace_id.attempt
            else:
                # Record the stack of the log call to better diagnose why we
                # don't have a frame id for it
                record["stack"] = torch._logging.structured.from_traceback(
                    CapturedTraceback.extract(skip=1).summary()
                )
        payload = payload_fn()
        if payload is not None:
            if not isinstance(payload, str):
                if isinstance(payload, list):
                    # special case to look better
                    payload = "[\n" + ",\n".join(json.dumps(i) for i in payload) + "\n]"
                else:
                    # force newlines so we are unlikely to overflow line limit
                    payload = json.dumps(payload, indent=0)
            h = hashlib.md5()
            h.update(payload.encode("utf-8"))
            record["has_payload"] = h.hexdigest()
        trace_log.debug(
            "", extra={"metadata": record, "payload": payload}, stacklevel=2
        )
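
# Illustrative call (the structured-trace names used in practice are chosen by
# the callers, e.g. in dynamo/inductor; this one is made up):
#
#     trace_structured(
#         "example_artifact",
#         metadata_fn=lambda: {"frame": "f", "reason": "demo"},
#         payload_fn=lambda: "line 1\nline 2",
#     )
#
# If TORCH_TRACE is set (so trace_log has a handler), this emits one JSON
# metadata record plus an md5-keyed payload into the dedicated trace log file.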


import torch._guards
import torch._utils_internal
import torch.distributed as dist
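
# NB: the three imports above sit at the bottom of the module, presumably to
# avoid import cycles: torch._guards, torch._utils_internal and
# torch.distributed are only needed at call time (e.g. inside
# TorchLogsFormatter and trace_structured), not at import time.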