| 123456789101112131415161718192021222324252627282930313233343536373839404142 |
- # mypy: allow-untyped-defs
- import traceback as tb
- from typing import Any, Dict, Tuple
# Type alias for a captured failure: the exception object paired with the
# traceback.StackSummary extracted from it.
- WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
# Only CheckpointException is exported; the underscore-prefixed helpers below
# are not listed as public API.
- __all__ = ["CheckpointException"]
def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
    """Pair ``exc`` with a summary of the traceback currently attached to it.

    Args:
        exc: The exception to wrap.

    Returns:
        A 2-tuple ``(exc, stack_summary)`` where ``stack_summary`` is the
        result of ``traceback.extract_tb`` applied to ``exc.__traceback__``.
    """
    stack_summary = tb.extract_tb(exc.__traceback__)
    return exc, stack_summary
def _is_wrapped_exception(obj: Any) -> bool:
    """Tell whether ``obj`` has the shape of a wrapped exception.

    Args:
        obj: Any object.

    Returns:
        True only when ``obj`` is a tuple of exactly two items whose first
        item is a ``BaseException`` and whose second item is a
        ``traceback.StackSummary``; False otherwise.
    """
    has_pair_shape = isinstance(obj, tuple) and len(obj) == 2
    if not has_pair_shape:
        return False
    error, stack = obj[0], obj[1]
    return isinstance(error, BaseException) and isinstance(stack, tb.StackSummary)
class CheckpointException(BaseException):
    """Exception raised if failure was detected as part of a checkpoint load or save."""

    def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
        """Create the exception.

        Args:
            msg: Human-readable description of the failure.
            failures: Mapping from rank to the ``(exception, stack summary)``
                pair captured on that rank.
        """
        # Both arguments are forwarded to BaseException so they are kept in
        # ``self.args`` alongside the dedicated ``_failures`` attribute.
        super().__init__(msg, failures)
        self._failures = failures

    @property
    def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
        """Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
        return self._failures

    def __str__(self) -> str:
        """Render a header line followed by one traceback section per rank.

        Each section has a ``Traceback ... (RANK n)`` line, the formatted
        stack entries (skipped when the stored trace is ``None``), and the
        formatted exception type and message.
        """
        # Collect fragments and join once; this also avoids binding a local
        # named ``str`` that would shadow the builtin.
        parts = [f"CheckpointException ranks:{self._failures.keys()}\n"]
        for rank, (exc, trace) in self._failures.items():
            parts.append(f"Traceback (most recent call last): (RANK {rank})\n")
            if trace is not None:
                parts.extend(tb.format_list(trace))
            parts.extend(tb.format_exception_only(type(exc), value=exc))
        return "".join(parts)
|