logging_utils.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # mypy: ignore-errors
  2. import torch._dynamo.test_case
  3. import unittest.mock
  4. import os
  5. import contextlib
  6. import torch._logging
  7. import torch._logging._internal
  8. from torch._dynamo.utils import LazyString
  9. from torch._inductor import config as inductor_config
  10. import logging
  11. import io
  12. @contextlib.contextmanager
  13. def preserve_log_state():
  14. prev_state = torch._logging._internal._get_log_state()
  15. torch._logging._internal._set_log_state(torch._logging._internal.LogState())
  16. try:
  17. yield
  18. finally:
  19. torch._logging._internal._set_log_state(prev_state)
  20. torch._logging._internal._init_logs()
  21. def log_settings(settings):
  22. exit_stack = contextlib.ExitStack()
  23. settings_patch = unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings})
  24. exit_stack.enter_context(preserve_log_state())
  25. exit_stack.enter_context(settings_patch)
  26. torch._logging._internal._init_logs()
  27. return exit_stack
  28. def log_api(**kwargs):
  29. exit_stack = contextlib.ExitStack()
  30. exit_stack.enter_context(preserve_log_state())
  31. torch._logging.set_logs(**kwargs)
  32. return exit_stack
  33. def kwargs_to_settings(**kwargs):
  34. INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}
  35. settings = []
  36. def append_setting(name, level):
  37. if isinstance(name, str) and isinstance(level, int) and level in INT_TO_VERBOSITY:
  38. settings.append(INT_TO_VERBOSITY[level] + name)
  39. return
  40. else:
  41. raise ValueError("Invalid value for setting")
  42. for name, val in kwargs.items():
  43. if isinstance(val, bool):
  44. settings.append(name)
  45. elif isinstance(val, int):
  46. append_setting(name, val)
  47. elif isinstance(val, dict) and name == "modules":
  48. for module_qname, level in val.items():
  49. append_setting(module_qname, level)
  50. else:
  51. raise ValueError("Invalid value for setting")
  52. return ",".join(settings)
# Note on testing strategy:
# The make_logging_test decorator below does three things:
# 1. Runs two versions of a test:
#    1a. one that patches the TORCH_LOGS env var to some specific settings value
#    1b. one that calls torch._logging.set_logs(..)
# 2. Patches the emit method of each set-up handler to gather the records
#    that are emitted to each console stream.
# 3. Passes a ref to the gathered records to each test case for checking.
#
# The general goal of this testing is to ensure that, given some settings env var,
# the logs are set up correctly and capture the correct records.
  64. def make_logging_test(**kwargs):
  65. def wrapper(fn):
  66. @inductor_config.patch({"fx_graph_cache": False})
  67. def test_fn(self):
  68. torch._dynamo.reset()
  69. records = []
  70. # run with env var
  71. if len(kwargs) == 0:
  72. with self._handler_watcher(records):
  73. fn(self, records)
  74. else:
  75. with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records):
  76. fn(self, records)
  77. # run with API
  78. torch._dynamo.reset()
  79. records.clear()
  80. with log_api(**kwargs), self._handler_watcher(records):
  81. fn(self, records)
  82. return test_fn
  83. return wrapper
  84. def make_settings_test(settings):
  85. def wrapper(fn):
  86. def test_fn(self):
  87. torch._dynamo.reset()
  88. records = []
  89. # run with env var
  90. with log_settings(settings), self._handler_watcher(records):
  91. fn(self, records)
  92. return test_fn
  93. return wrapper
  94. class LoggingTestCase(torch._dynamo.test_case.TestCase):
  95. @classmethod
  96. def setUpClass(cls):
  97. super().setUpClass()
  98. cls._exit_stack.enter_context(
  99. unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
  100. )
  101. cls._exit_stack.enter_context(
  102. unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
  103. )
  104. cls._exit_stack.enter_context(
  105. unittest.mock.patch("torch._dynamo.config.verbose", False)
  106. )
  107. @classmethod
  108. def tearDownClass(cls):
  109. cls._exit_stack.close()
  110. torch._logging._internal.log_state.clear()
  111. torch._logging._init_logs()
  112. def hasRecord(self, records, m):
  113. return any(m in r.getMessage() for r in records)
  114. def getRecord(self, records, m):
  115. record = None
  116. for r in records:
  117. # NB: not r.msg because it looks like 3.11 changed how they
  118. # structure log records
  119. if m in r.getMessage():
  120. self.assertIsNone(
  121. record,
  122. msg=LazyString(
  123. lambda: f"multiple matching records: {record} and {r} among {records}"
  124. ),
  125. )
  126. record = r
  127. if record is None:
  128. self.fail(f"did not find record with {m} among {records}")
  129. return record
  130. # This patches the emit method of each handler to gather records
  131. # as they are emitted
  132. def _handler_watcher(self, record_list):
  133. exit_stack = contextlib.ExitStack()
  134. def emit_post_hook(record):
  135. nonlocal record_list
  136. record_list.append(record)
  137. # registered logs are the only ones with handlers, so patch those
  138. for log_qname in torch._logging._internal.log_registry.get_log_qnames():
  139. logger = logging.getLogger(log_qname)
  140. num_handlers = len(logger.handlers)
  141. self.assertLessEqual(
  142. num_handlers,
  143. 2,
  144. "All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).",
  145. )
  146. self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers")
  147. for handler in logger.handlers:
  148. old_emit = handler.emit
  149. def new_emit(record):
  150. old_emit(record)
  151. emit_post_hook(record)
  152. exit_stack.enter_context(
  153. unittest.mock.patch.object(handler, "emit", new_emit)
  154. )
  155. return exit_stack
  156. def logs_to_string(module, log_option):
  157. """Example:
  158. logs_to_string("torch._inductor.compile_fx", "post_grad_graphs")
  159. returns the output of TORCH_LOGS="post_grad_graphs" from the
  160. torch._inductor.compile_fx module.
  161. """
  162. log_stream = io.StringIO()
  163. handler = logging.StreamHandler(stream=log_stream)
  164. @contextlib.contextmanager
  165. def tmp_redirect_logs():
  166. try:
  167. logger = torch._logging.getArtifactLogger(module, log_option)
  168. logger.addHandler(handler)
  169. yield
  170. finally:
  171. logger.removeHandler(handler)
  172. def ctx_manager():
  173. exit_stack = log_settings(log_option)
  174. exit_stack.enter_context(tmp_redirect_logs())
  175. return exit_stack
  176. return log_stream, ctx_manager