# log_utils.py
  1. import logging
  2. import logging.config
  3. import os
  4. from typing import Optional
  5. import torch.distributed as dist
  6. LOGGING_CONFIG = {
  7. "version": 1,
  8. "formatters": {
  9. "spmd_format": {"format": "%(name)s: [%(levelname)s] %(message)s"},
  10. "graph_opt_format": {"format": "%(name)s: [%(levelname)s] %(message)s"},
  11. },
  12. "handlers": {
  13. "spmd_console": {
  14. "class": "logging.StreamHandler",
  15. "level": "DEBUG",
  16. "formatter": "spmd_format",
  17. "stream": "ext://sys.stdout",
  18. },
  19. "graph_opt_console": {
  20. "class": "logging.StreamHandler",
  21. "level": "DEBUG",
  22. "formatter": "graph_opt_format",
  23. "stream": "ext://sys.stdout",
  24. },
  25. "null_console": {
  26. "class": "logging.NullHandler",
  27. },
  28. },
  29. "loggers": {
  30. "spmd_exp": {
  31. "level": "DEBUG",
  32. "handlers": ["spmd_console"],
  33. "propagate": False,
  34. },
  35. "graph_opt": {
  36. "level": "DEBUG",
  37. "handlers": ["graph_opt_console"],
  38. "propagate": False,
  39. },
  40. "null_logger": {
  41. "handlers": ["null_console"],
  42. "propagate": False,
  43. },
  44. # TODO(anj): Add loggers for MPMD
  45. },
  46. "disable_existing_loggers": False,
  47. }
  48. def get_logger(log_type: str) -> Optional[logging.Logger]:
  49. from torch.distributed._spmd import config
  50. if "PYTEST_CURRENT_TEST" not in os.environ:
  51. logging.config.dictConfig(LOGGING_CONFIG)
  52. avail_loggers = list(LOGGING_CONFIG["loggers"].keys()) # type: ignore[attr-defined]
  53. assert (
  54. log_type in avail_loggers
  55. ), f"Unable to find {log_type} in the available list of loggers {avail_loggers}"
  56. if not dist.is_initialized():
  57. return logging.getLogger(log_type)
  58. if dist.get_rank() == 0:
  59. logger = logging.getLogger(log_type)
  60. logger.setLevel(config.log_level)
  61. if config.log_file_name is not None:
  62. log_file = logging.FileHandler(config.log_file_name)
  63. log_file.setLevel(config.log_level)
  64. logger.addHandler(log_file)
  65. else:
  66. logger = logging.getLogger("null_logger")
  67. return logger
  68. return logging.getLogger("null_logger")