# stateful.py
from typing import Any, Dict, runtime_checkable, TypeVar

from typing_extensions import Protocol
  3. __all__ = ["Stateful", "StatefulT"]
  4. @runtime_checkable
  5. class Stateful(Protocol):
  6. """
  7. Stateful protocol for objects that can be checkpointed and restored.
  8. """
  9. def state_dict(self) -> Dict[str, Any]:
  10. """
  11. Objects should return their state_dict representation as a dictionary.
  12. The output of this function will be checkpointed, and later restored in
  13. `load_state_dict()`.
  14. .. warning::
  15. Because of the inplace nature of restoring a checkpoint, this function
  16. is also called during `torch.distributed.checkpoint.load`.
  17. Returns:
  18. Dict: The objects state dict
  19. """
  20. ...
  21. def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
  22. """
  23. Restore the object's state from the provided state_dict.
  24. Args:
  25. state_dict: The state dict to restore from
  26. """
  27. ...
  28. StatefulT = TypeVar("StatefulT", bound=Stateful)