# mypy: allow-untyped-defs
import itertools
import logging

from torch.hub import _Faketqdm, tqdm

# Disable progress bar by default, not in dynamo config because otherwise get a circular import
disable_progress = True


# Return all loggers that torchdynamo/torchinductor is responsible for
def get_loggers():
    return [
        logging.getLogger("torch.fx.experimental.symbolic_shapes"),
        logging.getLogger("torch._dynamo"),
        logging.getLogger("torch._inductor"),
    ]


# Creates a logging function that logs a message with a step # prepended.
# get_step_logger should be lazily called (i.e. at runtime, not at module-load time)
# so that step numbers are initialized properly. e.g.:

# @functools.lru_cache(None)
# def _step_logger():
#     return get_step_logger(logging.getLogger(...))

# def fn():
#     _step_logger()(logging.INFO, "msg")

_step_counter = itertools.count(1)

# Update num_steps if more phases are added: Dynamo, AOT, Backend
# This is very inductor centric
# _inductor.utils.has_triton() gives a circular import error here
if not disable_progress:
    try:
        import triton  # noqa: F401

        num_steps = 3
    except ImportError:
        num_steps = 2
    pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0)


def get_step_logger(logger):
    if not disable_progress:
        pbar.update(1)
        if not isinstance(pbar, _Faketqdm):
            pbar.set_postfix_str(f"{logger.name}")

    step = next(_step_counter)

    def log(level, msg, **kwargs):
        logger.log(level, "Step %s: %s", step, msg, **kwargs)

    return log
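

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how a caller might use get_step_logger, following the
# lazy-call pattern described in the comment above. The logger name and the
# __main__ guard are assumptions added for demonstration only.
if __name__ == "__main__":
    import functools

    logging.basicConfig(level=logging.INFO)

    @functools.lru_cache(None)
    def _step_logger():
        # Called lazily so the step number is taken at runtime, not at import time.
        return get_step_logger(logging.getLogger("torch._dynamo"))

    # Both calls reuse the same step number because _step_logger() is cached.
    _step_logger()(logging.INFO, "start tracing")
    _step_logger()(logging.INFO, "done tracing")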