api.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # mypy: allow-untyped-defs
  2. import traceback as tb
  3. from typing import Any, Dict, Tuple
  4. WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
  5. __all__ = ["CheckpointException"]
  6. def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
  7. return (exc, tb.extract_tb(exc.__traceback__))
  8. def _is_wrapped_exception(obj: Any) -> bool:
  9. if not isinstance(obj, tuple):
  10. return False
  11. if len(obj) != 2:
  12. return False
  13. return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary)
  14. class CheckpointException(BaseException):
  15. """Exception raised if failure was detected as part of a checkpoint load or save."""
  16. def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
  17. super().__init__(msg, failures)
  18. self._failures = failures
  19. @property
  20. def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
  21. """Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
  22. return self._failures
  23. def __str__(self):
  24. str = f"CheckpointException ranks:{self._failures.keys()}\n"
  25. for rank, exc_pair in self._failures.items():
  26. exc, trace = exc_pair
  27. str += f"Traceback (most recent call last): (RANK {rank})\n"
  28. if trace is not None:
  29. str += "".join(tb.format_list(trace))
  30. str += "".join(tb.format_exception_only(type(exc), value=exc))
  31. return str