| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- # mypy: allow-untyped-defs
- import functools
- import time
- from typing import Any, Callable, Dict, List, TypeVar
- from typing_extensions import ParamSpec
- import torch.distributed.c10d_logger as c10d_logger
- from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
- __all__: List[str] = []
- global _dcp_logger
- _dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
- _T = TypeVar("_T")
- _P = ParamSpec("_P")
def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
    """
    Build the log fields that can be derived from a DCP method's arguments.

    Only keyword arguments are inspected; positional arguments are accepted
    but not used.  The returned dict has a single key, ``"checkpoint_id"``,
    whose value is the id converted to ``str``, or ``None`` when the id
    resolves to ``None``.
    """
    ckpt_id = kwargs.get("checkpoint_id")
    if not ckpt_id:
        # No usable explicit id: fall back to the one carried by the storage
        # writer or, failing that, the storage reader.
        storage = kwargs.get("storage_writer") or kwargs.get("storage_reader")
        if storage:
            ckpt_id = getattr(storage, "checkpoint_id", None)

    return {"checkpoint_id": None if ckpt_id is None else str(ckpt_id)}
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
    """
    Assemble the full log payload for a call to ``func_name``.

    The DCP-specific fields are extracted from the call arguments first and
    then handed as keyword arguments to ``c10d_logger._get_msg_dict``; the
    entries it returns are merged on top of the DCP fields.
    """
    payload = _msg_dict_from_dcp_method_args(*args, **kwargs)
    c10d_fields = c10d_logger._get_msg_dict(func_name, **payload)
    payload.update(c10d_fields)
    return payload
def _dcp_method_logger(
    log_exceptions: bool = False, **wrapper_kwargs: Any
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:  # pyre-ignore
    """
    Build a decorator that logs the start, end, and exception of wrapped events.

    Args:
        log_exceptions: when true, an ``"exception"`` event is logged at
            error level before the exception is re-raised; when false the
            exception propagates without being logged.
        **wrapper_kwargs: extra keyword arguments merged with the call's own
            keyword arguments (the call's values win on conflict) when the
            log payload is built.  They are not forwarded to the wrapped
            function.

    Returns:
        A decorator.  The wrapped function is called with its original
        arguments and its return value is passed through unchanged.
    """

    def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            # One payload dict is built per call and mutated in place for
            # each event that is logged.
            msg_dict = _get_msg_dict(
                func.__name__, *args, **{**wrapper_kwargs, **kwargs}
            )

            # log start event
            t0 = time.time_ns()
            msg_dict["event"] = "start"
            msg_dict["time"] = t0
            _dcp_logger.debug(msg_dict)

            # exceptions
            try:
                result = func(*args, **kwargs)
            except Exception as error:
                if log_exceptions:
                    msg_dict["event"] = "exception"
                    msg_dict["error"] = f"{error}"
                    msg_dict["time"] = time.time_ns()
                    _dcp_logger.error(msg_dict)
                # Always propagate; logging never swallows the error.
                raise

            # end event: a single clock reading feeds both fields, so the
            # logged end time and the elapsed nanoseconds agree with each
            # other (time - times_spent == start time).
            t1 = time.time_ns()
            msg_dict["event"] = "end"
            msg_dict["time"] = t1
            msg_dict["times_spent"] = t1 - t0
            _dcp_logger.debug(msg_dict)
            return result

        return wrapper

    return decorator
|