| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- # mypy: ignore-errors
- import torch._dynamo.test_case
- import unittest.mock
- import os
- import contextlib
- import torch._logging
- import torch._logging._internal
- from torch._dynamo.utils import LazyString
- from torch._inductor import config as inductor_config
- import logging
- import io
- @contextlib.contextmanager
- def preserve_log_state():
- prev_state = torch._logging._internal._get_log_state()
- torch._logging._internal._set_log_state(torch._logging._internal.LogState())
- try:
- yield
- finally:
- torch._logging._internal._set_log_state(prev_state)
- torch._logging._internal._init_logs()
def log_settings(settings):
    """Apply ``settings`` as the ``TORCH_LOGS`` environment variable.

    The current log state is preserved, ``os.environ["TORCH_LOGS"]`` is
    patched to ``settings`` and ``_init_logs()`` is re-run.  All of this
    happens immediately, at call time, not when the returned object is
    entered.

    Args:
        settings: The string to place in ``TORCH_LOGS``.

    Returns:
        A ``contextlib.ExitStack`` that undoes the environment patch and
        restores the previous log state when closed or exited.
    """
    with contextlib.ExitStack() as stack:
        stack.enter_context(preserve_log_state())
        stack.enter_context(
            unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings})
        )
        torch._logging._internal._init_logs()
        # Everything succeeded: transfer the registered cleanups to a new
        # stack owned by the caller.  If anything above raised, the ``with``
        # unwinds the partially applied state instead of leaking it.
        return stack.pop_all()
def log_api(**kwargs):
    """Apply log settings through ``torch._logging.set_logs(**kwargs)``.

    The current log state is preserved and ``set_logs`` is called
    immediately, at call time.

    Args:
        **kwargs: Forwarded unchanged to ``torch._logging.set_logs``.

    Returns:
        A ``contextlib.ExitStack`` that restores the previous log state when
        closed or exited.
    """
    with contextlib.ExitStack() as stack:
        stack.enter_context(preserve_log_state())
        torch._logging.set_logs(**kwargs)
        # Hand the cleanup to the caller only once ``set_logs`` succeeded; if
        # it raised, the ``with`` restores the previous log state right away.
        return stack.pop_all()
def kwargs_to_settings(**kwargs):
    """Convert ``set_logs``-style keyword arguments into a settings string.

    Each keyword is translated as follows:

    * ``name=True``  -> ``"name"``
    * ``name=False`` -> omitted (a disabled flag must not appear in the
      string, otherwise it would read as enabled)
    * ``name=<int level>`` -> the name prefixed with ``"+"`` for 10,
      nothing for 20, or ``"-"`` for 40
    * ``modules={qualified_name: <int level>, ...}`` -> one prefixed entry
      per module

    Args:
        **kwargs: Setting names mapped to a bool, an int level, or (for the
            ``modules`` key only) a dict of module name to int level.

    Returns:
        The individual settings joined with ``","``; an empty string when
        nothing is enabled.

    Raises:
        ValueError: If a value has an unsupported type or a level is not one
            of 10, 20 or 40.
    """
    # 10, 20 and 40 are the numeric values of logging.DEBUG / INFO / ERROR.
    INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}
    settings = []

    def append_setting(name, level):
        valid = (
            isinstance(name, str)
            and isinstance(level, int)
            and level in INT_TO_VERBOSITY
        )
        if not valid:
            raise ValueError("Invalid value for setting")
        settings.append(INT_TO_VERBOSITY[level] + name)

    for name, val in kwargs.items():
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(val, bool):
            if val:
                settings.append(name)
        elif isinstance(val, int):
            append_setting(name, val)
        elif isinstance(val, dict) and name == "modules":
            for module_qname, level in val.items():
                append_setting(module_qname, level)
        else:
            raise ValueError("Invalid value for setting")

    return ",".join(settings)
# Note on testing strategy:
# The helpers below do three things:
# 1. Run two versions of a test:
#    1a. patches the env var log settings to some specific value
#    1b. calls torch._logging.set_logs(..)
# 2. Patch the emit method of each set-up handler to gather the records
#    that are emitted to each console stream
# 3. Pass a ref to the gathered records to each test case for checking
#
# The goal of this testing in general is to ensure that, given some settings
# env var, the logs are set up correctly and capture the correct records.
def make_logging_test(**kwargs):
    """Build a decorator that runs a logging test under two configurations.

    The decorated function must accept ``(self, records)``.  It is invoked
    twice: first with the settings applied through the environment variable
    (skipped in favour of no settings change when ``kwargs`` is empty), then
    with the same settings applied through ``log_api(**kwargs)``.  In both
    runs ``records`` is the list filled by ``self._handler_watcher``.

    Args:
        **kwargs: Log settings in the form accepted by ``kwargs_to_settings``
            and ``log_api``.

    Returns:
        A decorator producing a ``test_fn(self)`` method.
    """

    def wrapper(fn):
        # NOTE(review): the FX graph cache is presumably disabled so that
        # compilation (and its logging) really happens on each run - confirm.
        @inductor_config.patch({"fx_graph_cache": False})
        def test_fn(self):
            captured = []

            # Pass 1: settings supplied through the environment variable.
            torch._dynamo.reset()
            if kwargs:
                env_settings = kwargs_to_settings(**kwargs)
                with log_settings(env_settings), self._handler_watcher(captured):
                    fn(self, captured)
            else:
                with self._handler_watcher(captured):
                    fn(self, captured)

            # Pass 2: the same settings supplied through the Python API.
            torch._dynamo.reset()
            captured.clear()
            with log_api(**kwargs), self._handler_watcher(captured):
                fn(self, captured)

        return test_fn

    return wrapper
def make_settings_test(settings):
    """Build a decorator that runs a test under a raw settings string.

    The decorated function must accept ``(self, records)``.  It is invoked
    once with ``settings`` applied through ``log_settings`` and with
    ``records`` being the list filled by ``self._handler_watcher``.

    Args:
        settings: Settings string handed to ``log_settings``.

    Returns:
        A decorator producing a ``test_fn(self)`` method.
    """

    def wrapper(fn):
        def test_fn(self):
            torch._dynamo.reset()
            captured = []
            # Settings are supplied through the environment variable only.
            env_ctx = log_settings(settings)
            with env_ctx, self._handler_watcher(captured):
                fn(self, captured)

        return test_fn

    return wrapper
- class LoggingTestCase(torch._dynamo.test_case.TestCase):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- cls._exit_stack.enter_context(
- unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
- )
- cls._exit_stack.enter_context(
- unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
- )
- cls._exit_stack.enter_context(
- unittest.mock.patch("torch._dynamo.config.verbose", False)
- )
- @classmethod
- def tearDownClass(cls):
- cls._exit_stack.close()
- torch._logging._internal.log_state.clear()
- torch._logging._init_logs()
- def hasRecord(self, records, m):
- return any(m in r.getMessage() for r in records)
- def getRecord(self, records, m):
- record = None
- for r in records:
- # NB: not r.msg because it looks like 3.11 changed how they
- # structure log records
- if m in r.getMessage():
- self.assertIsNone(
- record,
- msg=LazyString(
- lambda: f"multiple matching records: {record} and {r} among {records}"
- ),
- )
- record = r
- if record is None:
- self.fail(f"did not find record with {m} among {records}")
- return record
- # This patches the emit method of each handler to gather records
- # as they are emitted
- def _handler_watcher(self, record_list):
- exit_stack = contextlib.ExitStack()
- def emit_post_hook(record):
- nonlocal record_list
- record_list.append(record)
- # registered logs are the only ones with handlers, so patch those
- for log_qname in torch._logging._internal.log_registry.get_log_qnames():
- logger = logging.getLogger(log_qname)
- num_handlers = len(logger.handlers)
- self.assertLessEqual(
- num_handlers,
- 2,
- "All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).",
- )
- self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers")
- for handler in logger.handlers:
- old_emit = handler.emit
- def new_emit(record):
- old_emit(record)
- emit_post_hook(record)
- exit_stack.enter_context(
- unittest.mock.patch.object(handler, "emit", new_emit)
- )
- return exit_stack
def logs_to_string(module, log_option):
    """Capture an artifact logger's output in an in-memory string buffer.

    Example:
        logs_to_string("torch._inductor.compile_fx", "post_grad_graphs")
        returns the output of TORCH_LOGS="post_grad_graphs" from the
        torch._inductor.compile_fx module.

    Args:
        module: Module name passed to ``torch._logging.getArtifactLogger``.
        log_option: Artifact name; also used as the settings string handed
            to ``log_settings``.

    Returns:
        A ``(log_stream, ctx_manager)`` tuple.  ``log_stream`` is the
        ``io.StringIO`` receiving the output.  ``ctx_manager`` is a
        zero-argument callable returning a ``contextlib.ExitStack``; while
        that stack is open the settings are active and the stream handler is
        attached to the artifact logger.
    """
    log_stream = io.StringIO()
    handler = logging.StreamHandler(stream=log_stream)

    @contextlib.contextmanager
    def tmp_redirect_logs():
        # Resolve the logger before ``try`` so that a failing lookup raises
        # its own error rather than an UnboundLocalError from ``finally``.
        logger = torch._logging.getArtifactLogger(module, log_option)
        logger.addHandler(handler)
        try:
            yield
        finally:
            logger.removeHandler(handler)

    def ctx_manager():
        # ``log_settings`` returns an ExitStack whose settings are already
        # applied.  The ``with`` rolls them back if attaching the handler
        # fails; on success ownership moves to the caller via ``pop_all``.
        with log_settings(log_option) as exit_stack:
            exit_stack.enter_context(tmp_redirect_logs())
            return exit_stack.pop_all()

    return log_stream, ctx_manager
|