- # mypy: allow-untyped-defs
- import functools
- import hashlib
- import itertools
- import json
- import logging
- import os
- import os.path
- import re
- import tempfile
- from dataclasses import dataclass, field
- from importlib import __import__
- from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
- from weakref import WeakSet
- import torch._logging.structured
- from torch.utils._traceback import CapturedTraceback
- log = logging.getLogger(__name__)
- # This is a synthetic logger which doesn't correspond to an actual logger,
- # but handles all of our "tracing" logging, which is structured and doesn't go
- # to stderr but always goes to a dedicated log file. We don't put these
- # loggers in the classic module hierarchy, because we don't want a suppression
- # of logs to also cause a trace to get suppressed (traces typically are not
- # collected, unless we are in prod, in which case they always are collected.)
- #
- # TODO: Maybe we should allow for some sub-hierarchy so you can control which
- # traces you want to collect, for performance reasons.
- #
- # See https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit
- trace_log = logging.getLogger("torch.__trace")
- DEFAULT_LOG_LEVEL = logging.WARNING
- LOG_ENV_VAR = "TORCH_LOGS"
- LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT"
- LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT"
- TRACE_ENV_VAR = "TORCH_TRACE"
- @dataclass
- class LogRegistry:
- # shorthand name to log qualified name
- # Note: this only contains loggers registered
- # from register_log
- # e.g. "dynamo" -> "torch._dynamo"
- log_alias_to_log_qnames: Dict[str, List[str]] = field(default_factory=dict)
- # artifact logger qualified names;
- # this is populated lazily, as calls to getArtifactLogger occur,
- # and is currently formatted as <module>.__<artifact_name>
- # e.g. "torch._dynamo.convert_frame.__guards"
- artifact_log_qnames: Set[str] = field(default_factory=set)
- # child logs of registered logs if specified via open
- # registration by the user (ie placing "torch._dynamo.output_graph" in the env var)
- # these need to be tracked so their levels can be reset properly
- # e.g. "torch._dynamo.output_graph"
- child_log_qnames: Set[str] = field(default_factory=set)
- # artifact names, populated by register_artifact
- # e.g. "guards"
- artifact_names: Set[str] = field(default_factory=set)
- # Artifacts that should be listed by default in the help message
- visible_artifacts: Set[str] = field(default_factory=set)
- # A short description of each artifact
- artifact_descriptions: Dict[str, str] = field(default_factory=dict)
- # artifacts which are not displayed unless explicitly named in the
- # settings. Ex. output_code is NOT displayed even if the inductor
- # log level is set to DEBUG. It must be explicitly named in the settings
- off_by_default_artifact_names: Set[str] = field(default_factory=set)
- # logging format string for artifacts
- artifact_log_formatters: Dict[str, logging.Formatter] = field(default_factory=dict)
- def is_artifact(self, name):
- return name in self.artifact_names
- def is_log(self, alias):
- return alias in self.log_alias_to_log_qnames
- # register a log with an alias
- def register_log(self, alias, log_qnames: Union[str, List[str]]):
- if isinstance(log_qnames, str):
- log_qnames = [log_qnames]
- self.log_alias_to_log_qnames[alias] = log_qnames
- # register an artifact name
- def register_artifact_name(
- self, name, description, visible, off_by_default, log_format
- ):
- self.artifact_names.add(name)
- if visible:
- self.visible_artifacts.add(name)
- self.artifact_descriptions[name] = description
- # if off by default, don't enable it
- # when log_name's log_level is set to DEBUG
- if off_by_default:
- self.off_by_default_artifact_names.add(name)
- if log_format is not None:
- self.artifact_log_formatters[name] = logging.Formatter(log_format)
- # register the qualified name of an artifact log
- # this is needed to know which logs need to be reset
- # whenever the log_state is changed
- def register_artifact_log(self, artifact_log_qname):
- self.artifact_log_qnames.add(artifact_log_qname)
- def register_child_log(self, log_qname):
- self.child_log_qnames.add(log_qname)
- # flattens all the qnames together (TODO: consider memoizing?)
- def get_log_qnames(self) -> Set[str]:
- return {
- qname
- for qnames in self.log_alias_to_log_qnames.values()
- for qname in qnames
- }
- def get_artifact_log_qnames(self):
- return set(self.artifact_log_qnames)
- def get_child_log_qnames(self):
- return set(self.child_log_qnames)
- def is_off_by_default(self, artifact_qname):
- return artifact_qname in self.off_by_default_artifact_names
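- # Rough sketch of how a LogRegistry is populated (the real registrations live in
- # torch._logging.registrations; the names below are only illustrative):
- #   registry = LogRegistry()
- #   registry.register_log("dynamo", "torch._dynamo")
- #   registry.register_artifact_name(
- #       "guards", "print guards", visible=True, off_by_default=False, log_format=None
- #   )
- #   registry.is_log("dynamo")       # -> True
- #   registry.is_artifact("guards")  # -> True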
- @dataclass
- class LogState:
- # qualified log names -> currently set log level
- log_qname_to_level: Dict[str, str] = field(default_factory=dict)
- # the set of currently enabled artifacts
- artifact_names: Set[str] = field(default_factory=set)
- def enable_artifact(self, artifact_name):
- self.artifact_names.add(artifact_name)
- def is_artifact_enabled(self, name):
- return name in self.artifact_names
- def enable_log(self, log_qnames, log_level):
- if isinstance(log_qnames, str):
- log_qnames = [log_qnames]
- for log_qname in log_qnames:
- self.log_qname_to_level[log_qname] = log_level
- def get_log_level_pairs(self):
- """Returns all qualified module names for which the user requested
- explicit logging settings.
- .. warning::
- This function used to return all loggers, regardless of whether
- or not the user specified them; it now only returns logs which
- were explicitly mentioned by the user (and torch, which is
- always implicitly requested when we initialize our logging
- subsystem.)
- """
- return self.log_qname_to_level.items()
- def clear(self):
- self.log_qname_to_level.clear()
- self.artifact_names.clear()
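- # Sketch: LogState only records what the user asked for. After
- #   state = LogState()
- #   state.enable_log("torch._dynamo", logging.DEBUG)
- #   state.enable_artifact("guards")
- # state.get_log_level_pairs() yields ("torch._dynamo", logging.DEBUG) and
- # state.is_artifact_enabled("guards") returns True.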
- log_registry = LogRegistry()
- log_state = LogState()
- # sample usage: torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING)
- DEFAULT_LOGGING = {
- "dynamo": logging.DEBUG,
- "aot": logging.DEBUG,
- "inductor": logging.DEBUG,
- "ddp_graphs": True,
- "graph_breaks": True,
- "guards": True,
- "recompiles": True,
- "dynamic": logging.INFO,
- }
- def set_logs(
- *,
- all: Optional[int] = None,
- dynamo: Optional[int] = None,
- aot: Optional[int] = None,
- autograd: Optional[int] = None,
- dynamic: Optional[int] = None,
- inductor: Optional[int] = None,
- distributed: Optional[int] = None,
- dist_c10d: Optional[int] = None,
- dist_ddp: Optional[int] = None,
- dist_fsdp: Optional[int] = None,
- onnx: Optional[int] = None,
- bytecode: bool = False,
- aot_graphs: bool = False,
- aot_joint_graph: bool = False,
- ddp_graphs: bool = False,
- graph: bool = False,
- graph_code: bool = False,
- graph_breaks: bool = False,
- graph_sizes: bool = False,
- guards: bool = False,
- recompiles: bool = False,
- recompiles_verbose: bool = False,
- trace_source: bool = False,
- trace_call: bool = False,
- trace_bytecode: bool = False,
- output_code: bool = False,
- kernel_code: bool = False,
- schedule: bool = False,
- perf_hints: bool = False,
- post_grad_graphs: bool = False,
- onnx_diagnostics: bool = False,
- fusion: bool = False,
- overlap: bool = False,
- export: Optional[int] = None,
- modules: Optional[Dict[str, Union[int, bool]]] = None,
- cudagraphs: bool = False,
- sym_node: bool = False,
- compiled_autograd_verbose: bool = False,
- ):
- """
- Sets the log level for individual components and toggles individual log
- artifact types.
- .. warning:: This feature is a prototype and may have compatibility
- breaking changes in the future.
- .. note:: The ``TORCH_LOGS`` environment variable has complete precedence
- over this function, so if it was set, this function does nothing.
- A component is a set of related features in PyTorch. All of the log
- messages emitted from a given component have their own log levels. If the
- log level of a particular message has priority greater than or equal to its
- component's log level setting, it is emitted. Otherwise, it is suppressed.
- This allows you to, for instance, silence large groups of log messages that
- are not relevant to you and increase verbosity of logs for components that
- are relevant. The expected log level values, ordered from highest to lowest
- priority, are:
- * ``logging.CRITICAL``
- * ``logging.ERROR``
- * ``logging.WARNING``
- * ``logging.INFO``
- * ``logging.DEBUG``
- * ``logging.NOTSET``
- See documentation for the Python ``logging`` module for more information on
- log levels: `<https://docs.python.org/3/library/logging.html#logging-levels>`_
- An artifact is a particular type of log message. Each artifact is assigned
- to a parent component. A component can emit many different kinds of
- artifacts. In general, an artifact is emitted if either its corresponding
- setting in the argument list below is turned on or if its parent component
- is set to a log level less than or equal to the log level of the artifact.
- Keyword args:
- all (:class:`Optional[int]`):
- The default log level for all components. Default: ``logging.WARN``
- dynamo (:class:`Optional[int]`):
- The log level for the TorchDynamo component. Default: ``logging.WARN``
- aot (:class:`Optional[int]`):
- The log level for the AOTAutograd component. Default: ``logging.WARN``
- autograd (:class:`Optional[int]`):
- The log level for autograd. Default: ``logging.WARN``
- inductor (:class:`Optional[int]`):
- The log level for the TorchInductor component. Default: ``logging.WARN``
- dynamic (:class:`Optional[int]`):
- The log level for dynamic shapes. Default: ``logging.WARN``
- distributed (:class:`Optional[int]`):
- The log level for c10d communication operations and other debug info from PyTorch Distributed components.
- Default: ``logging.WARN``
- dist_c10d (:class:`Optional[int]`):
- The log level for debug info related to c10d communication operations in PyTorch Distributed components.
- Default: ``logging.WARN``
- dist_ddp (:class:`Optional[int]`):
- The log level for debug info related to ``DistributedDataParallel`` (DDP) from PyTorch Distributed components.
- Default: ``logging.WARN``
- dist_fsdp (:class:`Optional[int]`):
- The log level for debug info related to ``FullyShardedDataParallel`` (FSDP) in PyTorch Distributed components.
- Default: ``logging.WARN``
- onnx (:class:`Optional[int]`):
- The log level for the ONNX exporter component. Default: ``logging.WARN``
- bytecode (:class:`bool`):
- Whether to emit the original and generated bytecode from TorchDynamo.
- Default: ``False``
- aot_graphs (:class:`bool`):
- Whether to emit the graphs generated by AOTAutograd. Default: ``False``
- aot_joint_graph (:class:`bool`):
- Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False``
- cudagraphs (:class:`bool`):
- Whether to log information from inductor cudagraphs. Default: ``False``
- ddp_graphs (:class:`bool`):
- Whether to emit graphs generated by DDPOptimizer. Default: ``False``
- graph (:class:`bool`):
- Whether to emit the graph captured by TorchDynamo in tabular format.
- Default: ``False``
- graph_code (:class:`bool`):
- Whether to emit the python source of the graph captured by TorchDynamo.
- Default: ``False``
- graph_breaks (:class:`bool`):
- Whether to emit the graph breaks encountered by TorchDynamo.
- Default: ``False``
- graph_sizes (:class:`bool`):
- Whether to emit tensor sizes of the graph captured by TorchDynamo.
- Default: ``False``
- guards (:class:`bool`):
- Whether to emit the guards generated by TorchDynamo for each compiled
- function. Default: ``False``
- recompiles (:class:`bool`):
- Whether to emit a guard failure reason and message every time
- TorchDynamo recompiles a function. Default: ``False``
- recompiles_verbose (:class:`bool`):
- Whether to emit all guard failure reasons when TorchDynamo recompiles
- a function, even those that are not actually run. Default: ``False``
- trace_source (:class:`bool`):
- Whether to emit when TorchDynamo begins tracing a new line. Default: ``False``
- trace_call (:class:`bool`):
- Whether to emit detailed line location when TorchDynamo creates an FX node
- corresponding to function call. Python 3.11+ only. Default: ``False``
- trace_bytecode (:class:`bool`):
- Whether to emit bytecode instructions and traced stack state as TorchDynamo
- traces bytecode. Default: ``False``
- output_code (:class:`bool`):
- Whether to emit the TorchInductor output code on a per-graph basis. Default: ``False``
- kernel_code (:class:`bool`):
- Whether to emit the TorchInductor output code on a per-kernel basis. Default: ``False``
- schedule (:class:`bool`):
- Whether to emit the TorchInductor schedule. Default: ``False``
- perf_hints (:class:`bool`):
- Whether to emit the TorchInductor perf hints. Default: ``False``
- post_grad_graphs (:class:`bool`):
- Whether to emit the graphs generated after the post-grad passes. Default: ``False``
- onnx_diagnostics (:class:`bool`):
- Whether to emit the ONNX exporter diagnostics in logging. Default: ``False``
- fusion (:class:`bool`):
- Whether to emit detailed Inductor fusion decisions. Default: ``False``
- overlap (:class:`bool`):
- Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False``
- sym_node (:class:`bool`):
- Whether to emit debug info for various SymNode operations. Default: ``False``
- export (:class:`Optional[int]`):
- The log level for export. Default: ``logging.WARN``
- modules (dict):
- This argument provides an alternate way to specify the above log
- component and artifact settings, in the format of a keyword args
- dictionary given as a single argument. There are two cases
- where this is useful (1) if a new log component or artifact has
- been registered but a keyword argument for it has not been added
- to this function and (2) if the log level for an unregistered module
- needs to be set. This can be done by providing the fully-qualified module
- name as the key, with the log level as the value. Default: ``None``
- Example::
- >>> # xdoctest: +SKIP
- >>> import logging
- # The following changes the "dynamo" component to emit DEBUG-level
- # logs, and to emit "graph_code" artifacts.
- >>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True)
- # The following enables the logs for a different module
- >>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG})
- """
- # ignore if env var is set
- if LOG_ENV_VAR in os.environ:
- log.warning(
- "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs"
- )
- return
- log_state.clear()
- modules = modules or {}
- def _set_logs(**kwargs):
- for alias, val in itertools.chain(kwargs.items(), modules.items()): # type: ignore[union-attr]
- if val is None:
- continue
- if log_registry.is_artifact(alias):
- if not isinstance(val, bool):
- raise ValueError(
- f"Expected bool to enable artifact {alias}, received {val}"
- )
- if val:
- log_state.enable_artifact(alias)
- elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames:
- if val not in logging._levelToName:
- raise ValueError(
- f"Unrecognized log level for log {alias}: {val}, valid level values "
- f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
- )
- log_state.enable_log(
- log_registry.log_alias_to_log_qnames.get(alias, alias), val
- )
- else:
- raise ValueError(
- f"Unrecognized log or artifact name passed to set_logs: {alias}"
- )
- _init_logs()
- _set_logs(
- torch=all,
- dynamo=dynamo,
- aot=aot,
- autograd=autograd,
- inductor=inductor,
- dynamic=dynamic,
- bytecode=bytecode,
- aot_graphs=aot_graphs,
- aot_joint_graph=aot_joint_graph,
- ddp_graphs=ddp_graphs,
- distributed=distributed,
- dist_c10d=dist_c10d,
- dist_ddp=dist_ddp,
- dist_fsdp=dist_fsdp,
- graph=graph,
- graph_code=graph_code,
- graph_breaks=graph_breaks,
- graph_sizes=graph_sizes,
- guards=guards,
- recompiles=recompiles,
- recompiles_verbose=recompiles_verbose,
- trace_source=trace_source,
- trace_call=trace_call,
- trace_bytecode=trace_bytecode,
- output_code=output_code,
- kernel_code=kernel_code,
- schedule=schedule,
- perf_hints=perf_hints,
- post_grad_graphs=post_grad_graphs,
- onnx=onnx,
- onnx_diagnostics=onnx_diagnostics,
- fusion=fusion,
- overlap=overlap,
- sym_node=sym_node,
- export=export,
- cudagraphs=cudagraphs,
- compiled_autograd_verbose=compiled_autograd_verbose,
- )
- def get_loggers():
- """
- Returns: a list of all registered loggers
- """
- return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()]
- def register_log(setting_name, log_name):
- """
- Enables a log to be controlled by the env var and user API with the setting_name
- Args:
- setting_name: the shorthand name used in the env var and user API
- log_name: the log name that the setting_name is associated with
- """
- log_registry.register_log(setting_name, log_name)
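- # Hypothetical example (actual registrations live in torch._logging.registrations):
- #   register_log("dynamo", ["torch._dynamo", "torch._dynamo.convert_frame"])
- # after which TORCH_LOGS="dynamo" or set_logs(dynamo=logging.INFO) controls both loggers.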
- def register_artifact(
- setting_name, description, visible=False, off_by_default=False, log_format=None
- ):
- """
- Enables an artifact to be controlled by the env var and user API with the given setting_name
- Args:
- setting_name: the shorthand name used in the env var and user API
- description: A description of what this outputs
- visible: Whether it gets suggested to users by default
- off_by_default: whether this artifact should only be logged when explicitly enabled;
- if True, it is not emitted merely because its ancestor loggers are enabled at level DEBUG
- log_format: an optional format string applied to this artifact's log messages
- """
- log_registry.register_artifact_name(
- setting_name, description, visible, off_by_default, log_format
- )
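- # Hypothetical example (actual artifact descriptions live in torch._logging.registrations):
- #   register_artifact("guards", "print the guards for each compiled frame", visible=True)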
- def getArtifactLogger(module_qname, artifact_name):
- if artifact_name not in log_registry.artifact_names:
- raise ValueError(
- f"Artifact name: {repr(artifact_name)} not registered,"
- f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations."
- )
- qname = module_qname + f".__{artifact_name}"
- log = logging.getLogger(qname)
- log.artifact_name = artifact_name # type: ignore[attr-defined]
- log_registry.register_artifact_log(qname)
- configure_artifact_log(log)
- return log
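- # Typical use from inside a component module (sketch; "graph_breaks" is one of the
- # registered artifact names):
- #   graph_breaks_log = getArtifactLogger(__name__, "graph_breaks")
- #   graph_breaks_log.debug("...")   # only emitted when the artifact is enabled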
- INCR_VERBOSITY_CHAR = "+"
- DECR_VERBOSITY_CHAR = "-"
- VERBOSITY_REGEX = (
- "("
- + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)])
- + "?)"
- )
- def configure_artifact_log(log):
- # If the artifact is off by default, then it should only be logged when explicitly
- # enabled; set propagate to False so that this artifact is not propagated
- # to its ancestor logger
- if log_registry.is_off_by_default(log.artifact_name):
- log.propagate = False
- # enable artifact logging when explicitly enabled
- if log_state.is_artifact_enabled(log.artifact_name):
- log.setLevel(logging.DEBUG)
- log.propagate = True
- # match a comma separated list of loggable names (whitespace allowed after commas)
- def _gen_settings_regex():
- return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?")
- def _validate_settings(settings):
- return re.fullmatch(_gen_settings_regex(), settings) is not None
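- # For example, "dynamo", "+dynamo,-inductor" and "aot_graphs, +torch._dynamo.output_graph"
- # all validate, while "dynamo,,inductor" and "+,dynamo" do not (sketch based on the regex
- # above; _parse_log_settings below assigns the names their meaning).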
- def help_message(verbose=False):
- def pad_to(s, length=30):
- assert len(s) <= length
- return s + " " * (length - len(s))
- if verbose:
- printed_artifacts = log_registry.artifact_names
- else:
- printed_artifacts = log_registry.visible_artifacts
- if verbose:
- heading = "All registered names"
- else:
- heading = "Visible registered names (use TORCH_LOGS='+help' for full list)"
- lines = (
- ["all"]
- + sorted(log_registry.log_alias_to_log_qnames.keys())
- + sorted(
- [
- f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}"
- for name in printed_artifacts
- ]
- )
- )
- setting_info = " " + "\n ".join(lines)
- examples = """
- Examples:
- TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to
- logging.DEBUG and AOT to logging.INFO
- TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to
- logging.ERROR and TorchInductor to logging.DEBUG
- TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact
- TORCH_LOGS="+dynamo,schedule" will enable set the log level of TorchDynamo
- to logging.DEBUG and enable the schedule artifact
- TORCH_LOGS="+some.random.module,schedule" will set the log level of
- some.random.module to logging.DEBUG and enable the schedule artifact
- TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format
- string will set the output format
- Valid keys are "levelname", "message", "pathname", "levelno", "lineno",
- "filename" and "name".
- TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as
- well. This is useful when the output is long.
- """ # flake8: noqa: B950
- msg = f"""
- TORCH_LOGS Info
- {examples}
- {heading}
- {setting_info}
- """
- return msg
- def _invalid_settings_err_msg(settings, verbose=False):
- valid_settings = ", ".join(
- ["all"]
- + list(log_registry.log_alias_to_log_qnames.keys())
- + list(log_registry.artifact_names)
- )
- msg = f"""
- Invalid log settings: {settings}, must be a comma separated list of fully
- qualified module names, registered log names or registered artifact names.
- For more info on various settings, try TORCH_LOGS="help"
- Valid settings:
- {valid_settings}
- """
- return msg
- @functools.lru_cache
- def _parse_log_settings(settings):
- if settings == "":
- return dict()
- if settings == "help":
- raise ValueError(help_message(verbose=False))
- elif settings == "+help":
- raise ValueError(help_message(verbose=True))
- if not _validate_settings(settings):
- raise ValueError(_invalid_settings_err_msg(settings))
- settings = re.sub(r"\s+", "", settings)
- log_names = settings.split(",")
- def get_name_level_pair(name):
- clean_name = name.replace(INCR_VERBOSITY_CHAR, "")
- clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "")
- if name[0] == INCR_VERBOSITY_CHAR:
- level = logging.DEBUG
- elif name[0] == DECR_VERBOSITY_CHAR:
- level = logging.ERROR
- else:
- level = logging.INFO
- return clean_name, level
- log_state = LogState()
- for name in log_names:
- name, level = get_name_level_pair(name)
- if name == "all":
- name = "torch"
- if log_registry.is_log(name):
- assert level is not None
- log_qnames = log_registry.log_alias_to_log_qnames[name]
- log_state.enable_log(log_qnames, level)
- elif log_registry.is_artifact(name):
- log_state.enable_artifact(name)
- elif _is_valid_module(name):
- if not _has_registered_parent(name):
- log_registry.register_log(name, name)
- else:
- log_registry.register_child_log(name)
- log_state.enable_log(name, level)
- else:
- raise ValueError(_invalid_settings_err_msg(settings))
- return log_state
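- # Sketch of the LogState produced for TORCH_LOGS="+dynamo,guards,torch._dynamo.output_graph"
- # (assuming "dynamo" -> "torch._dynamo" is registered, as in the comments above):
- #   torch._dynamo               -> logging.DEBUG  ("+" selects DEBUG)
- #   guards                      -> enabled artifact
- #   torch._dynamo.output_graph  -> logging.INFO   (no prefix selects INFO)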
- def _is_valid_module(qname):
- try:
- __import__(qname)
- return True
- except ImportError:
- return False
- def _update_log_state_from_env():
- global log_state
- log_setting = os.environ.get(LOG_ENV_VAR, None)
- if log_setting is not None:
- log_state = _parse_log_settings(log_setting)
- def _has_registered_parent(log_qname):
- cur_log = logging.getLogger(log_qname)
- registered_log_qnames = log_registry.get_log_qnames()
- while cur_log.parent:
- if cur_log.name in registered_log_qnames:
- return True
- cur_log = cur_log.parent
- return False
- # apply custom formats to artifacts when necessary
- class TorchLogsFormatter(logging.Formatter):
- def __init__(self, *, trace: bool = False):
- super().__init__()
- self._is_trace = trace
- def format(self, record):
- artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None)
- if artifact_name is not None:
- artifact_formatter = log_registry.artifact_log_formatters.get(
- artifact_name, None
- )
- if artifact_formatter is not None:
- return artifact_formatter.format(record)
- record.message = record.getMessage()
- record.asctime = self.formatTime(record, "%m%d %H:%M:%S")
- # exception handling - copied from logging.Formatter.format
- s = record.message
- if record.exc_info:
- # Cache the traceback text to avoid converting it multiple times
- # (it's constant anyway)
- if not record.exc_text:
- record.exc_text = self.formatException(record.exc_info)
- if record.exc_text:
- if s[-1:] != "\n":
- s = s + "\n"
- s = s + record.exc_text
- if record.stack_info:
- if s[-1:] != "\n":
- s = s + "\n"
- s = s + self.formatStack(record.stack_info)
- record.rankprefix = ""
- if not self._is_trace and dist.is_available() and dist.is_initialized():
- record.rankprefix = f"[rank{dist.get_rank()}]:"
- record.traceid = ""
- if (
- not self._is_trace
- and (trace_id := torch._guards.CompileContext.current_trace_id())
- is not None
- ):
- record.traceid = f" [{trace_id}]"
- glog_level_to_abbr = {
- "DEBUG": "V", # V is for VERBOSE in glog
- "INFO": "I",
- "WARNING": "W",
- "ERROR": "E",
- "CRITICAL": "C",
- }
- shortlevel = glog_level_to_abbr.get(record.levelname, record.levelname)
- record.artifactprefix = ""
- if artifact_name is not None:
- record.artifactprefix = f" [__{artifact_name}]"
- prefix = (
- f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs*1000):06d} {record.thread} "
- f"{os.path.relpath(record.pathname, os.path.dirname(os.path.dirname(torch.__file__)))}:"
- f"{record.lineno}]{record.traceid}{record.artifactprefix}"
- )
- if self._is_trace:
- assert s == ""
- try:
- r = f"{prefix} {json.dumps(record.metadata)}"
- except TypeError:
- log.warning("failing metadata: %r", record.metadata)
- raise
- if record.payload is not None:
- r += "".join(f"\n\t{l}" for l in record.payload.split("\n"))
- return r
- else:
- lines = s.split("\n")
- return "\n".join(f"{prefix} {l}" for l in lines)
- def _default_formatter():
- fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None)
- if fmt is None:
- return TorchLogsFormatter()
- else:
- if fmt in ("short", "basic"):
- fmt = logging.BASIC_FORMAT
- return logging.Formatter(fmt)
- DEFAULT_FORMATTER = _default_formatter()
- def _setup_handlers(create_handler_fn, log):
- debug_handler = _track_handler(create_handler_fn())
- debug_handler.setFormatter(DEFAULT_FORMATTER)
- debug_handler.setLevel(logging.DEBUG)
- log.addHandler(debug_handler)
- handlers = WeakSet() # type: ignore[var-annotated]
- # mark handlers that we've created
- # so we don't modify user handlers
- def _track_handler(handler):
- handlers.add(handler)
- return handler
- def _is_torch_handler(handler):
- return handler in handlers
- # clears all torch handlers on specified loggers
- def _clear_handlers(log):
- to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)]
- for handler in to_remove:
- log.removeHandler(handler)
- def _reset_logs():
- # reset all registered logs
- for log_qname in log_registry.get_log_qnames():
- log = logging.getLogger(log_qname)
- log.setLevel(logging.WARNING)
- log.propagate = False
- _clear_handlers(log)
- # reset all artifact and child logs
- for artifact_log_qname in itertools.chain(
- log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames()
- ):
- log = logging.getLogger(artifact_log_qname)
- log.setLevel(logging.NOTSET)
- log.propagate = True
- trace_log.propagate = False
- _clear_handlers(trace_log)
- def _get_log_state():
- return log_state
- def _set_log_state(state):
- global log_state
- log_state = state
- def _init_logs(log_file_name=None):
- _reset_logs()
- _update_log_state_from_env()
- out = os.environ.get(LOG_OUT_ENV_VAR, None)
- if out is not None:
- log_file_name = out
- # First, reset all known (registered) loggers to NOTSET, so that they
- # respect their parent log level
- for log_qname in log_registry.get_log_qnames():
- # But not the top level "torch" logger: it stays at its default of WARNING so
- # that, absent explicit settings, only WARNING and above get emitted
- if log_qname == "torch":
- continue
- log = logging.getLogger(log_qname)
- log.setLevel(logging.NOTSET)
- # Now, for all loggers which the user requested to have non-standard
- # logging behavior, modify their log levels
- for log_qname, level in log_state.get_log_level_pairs():
- log = logging.getLogger(log_qname)
- log.setLevel(level)
- # Finally, setup handlers for all registered loggers
- for log_qname in log_registry.get_log_qnames():
- log = logging.getLogger(log_qname)
- _setup_handlers(
- logging.StreamHandler,
- log,
- )
- if log_file_name is not None:
- _setup_handlers(
- lambda: logging.FileHandler(log_file_name),
- log,
- )
- # configure artifact loggers, note: this must happen last
- # since the levels of ancestor loggers are taken into account
- for artifact_log_qname in log_registry.get_artifact_log_qnames():
- log = logging.getLogger(artifact_log_qname)
- configure_artifact_log(log)
- # Setup handler for the special trace_log, with different default
- # configuration
- trace_dir_name = os.environ.get(TRACE_ENV_VAR, None)
- # This handler may remove itself if trace_dir_name is None and we are not
- # actually in an FB environment. This allows us to defer actually
- # initializing it until we actually need to log anything. This is
- # important because JK initializes a C++ singleton, which will pork our
- # process if we subsequently fork.
- handler = LazyTraceHandler(trace_dir_name)
- # This log is ALWAYS at debug level. We will additionally test if there
- are any handlers before deciding to actually call logging on this. Do
- not manually change its level.
- trace_log.setLevel(logging.DEBUG)
- trace_log_handler = _track_handler(handler)
- trace_log_handler.setFormatter(TorchLogsFormatter(trace=True))
- trace_log.addHandler(trace_log_handler)
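- # Rough call order (sketch): both torch._logging.set_logs(...) and the TORCH_LOGS env
- # var funnel into _init_logs(): reset all known loggers, apply the requested levels,
- # attach stream/file handlers, then reconfigure artifact loggers and trace_log last.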
- class LazyTraceHandler(logging.StreamHandler):
- """Like FileHandler, but the file is allocated lazily only upon the first log message"""
- def __init__(self, root_dir: Optional[str]):
- # This is implemented in the same way that delay is implemented on
- # FileHandler
- self.root_dir = root_dir
- logging.Handler.__init__(self)
- self.stream = None
- self._builtin_open = open
- # cloned from FileHandler in cpython
- def close(self):
- self.acquire()
- try:
- try:
- if self.stream:
- try:
- self.flush()
- finally:
- stream = self.stream
- self.stream = None
- if hasattr(stream, "close"):
- stream.close()
- finally:
- # Issue #19523: call unconditionally to
- # prevent a handler leak when delay is set
- # Also see Issue #42378: we also rely on
- # self._closed being set to True there
- logging.StreamHandler.close(self)
- finally:
- self.release()
- def emit(self, record):
- if self.stream is None:
- ok = False
- if self.root_dir is None:
- TRACE_LOG_DIR = "/logs"
- open_func = self._builtin_open
- import torch.version as torch_version
- if hasattr(torch_version, "git_version"):
- log.info("LazyTraceHandler: disabled because not fbcode")
- elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
- log.info(
- "LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
- )
- elif not os.path.exists(TRACE_LOG_DIR):
- log.info(
- "LazyTraceHandler: disabled because %s does not exist",
- TRACE_LOG_DIR,
- )
- elif not os.access(TRACE_LOG_DIR, os.W_OK):
- log.info(
- "LazyTraceHandler: disabled because %s is not writeable",
- TRACE_LOG_DIR,
- )
- else:
- self.root_dir = TRACE_LOG_DIR
- if self.root_dir is not None:
- os.makedirs(self.root_dir, exist_ok=True)
- ranksuffix = ""
- if dist.is_available() and dist.is_initialized():
- ranksuffix = f"rank_{dist.get_rank()}_"
- self.stream = tempfile.NamedTemporaryFile(
- mode="w+",
- suffix=".log",
- prefix=f"dedicated_log_torch_trace_{ranksuffix}",
- dir=self.root_dir,
- delete=False,
- )
- log.info("LazyTraceHandler: logging to %s", self.stream.name)
- else:
- # We go poof, remove and no-op
- trace_log.removeHandler(self)
- return
- if self.stream:
- super().emit(record)
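- # Net effect (sketch): with TORCH_TRACE=/tmp/torch_trace set, the first structured trace
- # record creates a file such as
- #   /tmp/torch_trace/dedicated_log_torch_trace_rank_0_<random>.log
- # (the rank_0_ part appears only when torch.distributed is initialized); without
- # TORCH_TRACE and outside fbcode, the handler removes itself and tracing is a no-op.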
- @functools.lru_cache(None)
- def warning_once(logger_obj, *args, **kwargs):
- """
- This function is similar to `logger.warning()`, but will emit a warning with the same message only once.
- Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
- The assumption here is that all warning messages are unique across the code. If they aren't, we would need to switch to
- another type of cache that includes the caller frame information in the hashing function.
- """
- logger_obj.warning(*args, **kwargs)
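- # Sketch: warning_once(log, "some message") emits the warning through `log` the first
- # time only; later calls with the same (logger, message) arguments hit the lru_cache
- # and are no-ops.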
- class LazyString:
- def __init__(self, func, *args, **kwargs):
- self.func = func
- self.args = args
- self.kwargs = kwargs
- def __str__(self):
- return self.func(*self.args, **self.kwargs)
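- # Sketch: defer an expensive string until (and unless) the record is actually handled.
- # `graph.print_readable()` below is a hypothetical expensive call:
- #   log.debug("graph: %s", LazyString(lambda: graph.print_readable()))
- # The %-formatting (and hence LazyString.__str__) only runs if the record is emitted.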
- def trace_structured(
- name: str,
- # NB: metadata expected to be dict so adding more info is forward compatible
- # Tuple[str, int] is a special case for string interning
- metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict,
- *,
- payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
- suppress_context: bool = False,
- ):
- """
- metadata is an arbitrary JSON compatible struct, but it's expected to not be
- too long (e.g., less than 1MB)
- payload is an arbitrary string, which can be arbitrarily long (but expected to have
- newlines so no lines are too long)
- """
- assert "name" not in ["rank", "frame_id", "frame_compile_id", "attempt"]
- assert callable(
- metadata_fn
- ), f"metadata_fn should be callable, but got {type(metadata_fn)}"
- assert callable(
- payload_fn
- ), f"payload_fn should be callable, but got {type(payload_fn)}"
- # trace_log never propagates and is ALWAYS DEBUG, so also check that there
- # are handlers instead of checking the log level
- if trace_log.handlers:
- record: Dict[str, object] = {}
- record[name] = metadata_fn()
- if not suppress_context:
- # TODO: Actually, the rank probably should just be emitted once at
- # the top, and not repeatedly spammed in all the logs, since it
- # never changes and we assume no interleaving
- if dist.is_available() and dist.is_initialized():
- record["rank"] = dist.get_rank()
- if (
- trace_id := torch._guards.CompileContext.current_trace_id()
- ) is not None:
- record["frame_id"] = trace_id.compile_id.frame_id
- record["frame_compile_id"] = trace_id.compile_id.frame_compile_id
- record["attempt"] = trace_id.attempt
- else:
- # Record the stack of the log call to better diagnose why we
- # don't have a frame id for it
- record["stack"] = torch._logging.structured.from_traceback(
- CapturedTraceback.extract(skip=1).summary()
- )
- payload = payload_fn()
- if payload is not None:
- if not isinstance(payload, str):
- if isinstance(payload, list):
- # special case to look better
- payload = "[\n" + ",\n".join(json.dumps(i) for i in payload) + "\n]"
- else:
- # force newlines so we are unlikely to overflow line limit
- payload = json.dumps(payload, indent=0)
- h = hashlib.md5()
- h.update(payload.encode("utf-8"))
- record["has_payload"] = h.hexdigest()
- trace_log.debug(
- "", extra={"metadata": record, "payload": payload}, stacklevel=2
- )
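- # Sketch of a typical call site (the metadata keys and `gm` here are illustrative,
- # not a required schema):
- #   trace_structured(
- #       "artifact",
- #       metadata_fn=lambda: {"name": "fx_graph", "encoding": "string"},
- #       payload_fn=lambda: gm.print_readable(print_output=False),
- #   )
- # metadata_fn/payload_fn are thunks, so the (possibly expensive) values are computed
- # only when a trace handler is actually installed.
- #
- # The imports below sit at the bottom of the module, presumably to avoid circular
- # imports during torch initialization; the code above only references them at call time.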
- import torch._guards
- import torch._utils_internal
- import torch.distributed as dist
|