zero_redundancy_optimizer.pyi

# mypy: allow-untyped-defs
import enum
from typing import Any, Callable, Dict, List, Optional, overload, Set, Type

import torch
from torch.distributed.algorithms.join import Joinable, JoinHook
from torch.optim import Optimizer

class _ZeROJoinHook(JoinHook):
    zero: Any = ...
    def __init__(self, zero: Any) -> None: ...
    def main_hook(self) -> None: ...

class _DDPBucketAssignment:
    bucket_index: int
    parameters: List[torch.Tensor]
    offset: int
    device: torch.device
    tensor: Optional[torch.Tensor]

class _OverlapStatus(enum.IntEnum):
    # Setup states when overlapping ZeRO with DistributedDataParallel,
    # advancing from UNINITIALIZED to INITIALIZED once DDP has rebuilt
    # its gradient buckets.
    UNINITIALIZED: int = ...
    DDP_HAS_REBUILT_BUCKETS: int = ...
    INITIALIZED: int = ...

class _OverlapInfo:
    status: Any = ...
    params_per_bucket: Any = ...
    params_per_rank: Any = ...
    offsets: Any = ...
    broadcast_handles: Any = ...
    bucket_index_to_future: Any = ...
    bucket_index_to_bucket: Any = ...
    bucket_indices_seen: Any = ...
    assigned_ranks_per_bucket: List[Set[int]] = ...
    total_size: int = ...
    shard_buckets: bool = ...
    def __init__(self) -> None: ...
    def wait_for_broadcasts(self) -> None: ...
    def clear_per_iter_info(self) -> None: ...

class ZeroRedundancyOptimizer(Optimizer, Joinable):
    functional_optim_map: Any = ...
    initialized: bool = ...
    process_group: Any = ...
    world_size: int = ...
    rank: int = ...
    global_rank: int = ...
    parameters_as_bucket_view: bool = ...
    optim: Any = ...
    _device_to_device_index: Dict[torch.device, int] = ...
    _overlap_with_ddp: bool = ...
    _overlap_info: _OverlapInfo = ...
    _buckets: List[List[torch.Tensor]] = ...
    _bucket_assignments_per_rank: List[Dict[int, _DDPBucketAssignment]] = ...
    def __init__(
        self,
        params: Any,
        optimizer_class: Type[Optimizer],
        process_group: Optional[Any] = ...,
        parameters_as_bucket_view: bool = ...,
        overlap_with_ddp: bool = ...,
        **defaults: Any,
    ) -> None: ...
    def add_param_group(self, param_group: Dict[str, Any]) -> None: ...
    # `to` is the rank that receives the consolidated optimizer state.
    def consolidate_state_dict(self, to: int = ...) -> None: ...
    # step() returns the closure's loss when a closure is supplied and
    # None otherwise; the overloads below encode that contract.
    @overload
    def step(self, closure: None = ..., **kwargs: Any) -> None: ...
    @overload
    def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ...
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...
    def state_dict(self) -> Dict[str, Any]: ...
    def _local_step(
        self,
        gradients: Optional[List[Optional[torch.Tensor]]] = None,
        closure: Optional[Callable[[], float]] = None,
        **kwargs: Any,
    ) -> Optional[float]: ...
    def _get_assigned_rank(self, bucket_index: int) -> int: ...
    def _init_zero_for_overlap(self) -> None: ...
    def join_hook(self, **kwargs: Any) -> JoinHook: ...
    @property
    def join_device(self) -> torch.device: ...
    @property
    def join_process_group(self) -> Any: ...
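
For context, here is a minimal usage sketch of the public API declared above. It is not part of the stub; it assumes torch.distributed is already initialized (e.g. via torchrun), and the model and input shapes are illustrative placeholders.

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP


def example_train_step(model: torch.nn.Module) -> None:
    # Wrap the model in DDP; ZeroRedundancyOptimizer then shards the
    # optimizer state across ranks, so each rank stores roughly a
    # 1/world_size slice of (for example) Adam's moment buffers.
    ddp_model = DDP(model)
    optimizer = ZeroRedundancyOptimizer(
        ddp_model.parameters(),
        optimizer_class=torch.optim.Adam,  # any torch.optim.Optimizer subclass
        lr=1e-3,  # forwarded via **defaults to the wrapped Adam
    )
    # Placeholder forward/backward pass; the input shape is illustrative.
    loss = ddp_model(torch.randn(2, 8)).sum()
    loss.backward()
    optimizer.step()
    # Gather the full (unsharded) optimizer state onto rank 0 before
    # calling state_dict(), per consolidate_state_dict() above.
    optimizer.consolidate_state_dict(to=0)
    if dist.get_rank() == 0:
        torch.save(optimizer.state_dict(), "optimizer_checkpoint.pt")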