logger.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import time
  4. from typing import Any, Callable, Dict, List, TypeVar
  5. from typing_extensions import ParamSpec
  6. import torch.distributed.c10d_logger as c10d_logger
  7. from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
  8. __all__: List[str] = []
  9. global _dcp_logger
  10. _dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
  11. _T = TypeVar("_T")
  12. _P = ParamSpec("_P")
  13. def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
  14. """
  15. Extracts log data from dcp method args
  16. """
  17. msg_dict = {}
  18. # checkpoint ID can be passed in through the serializer or through the checkpoint id directly
  19. storage_writer = kwargs.get("storage_writer", None)
  20. storage_reader = kwargs.get("storage_reader", None)
  21. checkpoint_id = kwargs.get("checkpoint_id", None)
  22. if not checkpoint_id and (serializer := storage_writer or storage_reader):
  23. checkpoint_id = getattr(serializer, "checkpoint_id", None)
  24. msg_dict["checkpoint_id"] = (
  25. str(checkpoint_id) if checkpoint_id is not None else checkpoint_id
  26. )
  27. return msg_dict
  28. def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
  29. msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs)
  30. msg_dict.update(c10d_logger._get_msg_dict(func_name, **msg_dict))
  31. return msg_dict
  32. def _dcp_method_logger(
  33. log_exceptions: bool = False, **wrapper_kwargs: Any
  34. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # pyre-ignore
  35. """This method decorator logs the start, end, and exception of wrapped events."""
  36. def decorator(func: Callable[_P, _T]):
  37. @functools.wraps(func)
  38. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  39. msg_dict = _get_msg_dict(
  40. func.__name__, *args, **{**wrapper_kwargs, **kwargs}
  41. )
  42. # log start event
  43. msg_dict["event"] = "start"
  44. t0 = time.time_ns()
  45. msg_dict["time"] = t0
  46. _dcp_logger.debug(msg_dict)
  47. # exceptions
  48. try:
  49. result = func(*args, **kwargs)
  50. except Exception as error:
  51. if log_exceptions:
  52. msg_dict["event"] = "exception"
  53. msg_dict["error"] = f"{error}"
  54. msg_dict["time"] = time.time_ns()
  55. _dcp_logger.error(msg_dict)
  56. raise
  57. # end event
  58. msg_dict["event"] = "end"
  59. t1 = time.time_ns()
  60. msg_dict["time"] = time.time_ns()
  61. msg_dict["times_spent"] = t1 - t0
  62. _dcp_logger.debug(msg_dict)
  63. return result
  64. return wrapper
  65. return decorator