_distributed_c10d.pyi

# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from datetime import timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, overload, Tuple, Union

import torch
from torch import Tensor
from torch._C import ScriptObject
from torch.futures import Future

# This module is defined in torch/csrc/distributed/c10d/init.cpp

_DEFAULT_FIRST_BUCKET_BYTES: int
_DEFAULT_NO_TIMEOUT: timedelta
_DEFAULT_PG_TIMEOUT: timedelta
_DEFAULT_PG_NCCL_TIMEOUT: timedelta

class BuiltinCommHookType(Enum):
    ALLREDUCE = ...
    FP16_COMPRESS = ...

def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
def _register_builtin_comm_hook(
    reducer: Reducer,
    comm_hook_type: BuiltinCommHookType,
): ...
def _set_global_rank(rank: int) -> None: ...
def _hash_tensors(tensors: List[Tensor]) -> int: ...

class GradBucket:
    def index(self) -> int: ...
    def buffer(self) -> Tensor: ...
    def gradients(self) -> List[Tensor]: ...
    def is_last(self) -> bool: ...
    def set_buffer(self, tensor: Tensor) -> None: ...
    def parameters(self) -> List[Tensor]: ...
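
# A DDP communication hook receives each GradBucket and returns a Future of the
# reduced tensor. Minimal usage sketch, assuming the public torch.distributed
# wrappers; `process_group` and the DDP model below are placeholders, not part
# of this stub:
#
#     import torch
#     import torch.distributed as dist
#
#     def allreduce_hook(process_group, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
#         # Pre-divide so the summed result is the average across ranks.
#         tensor = bucket.buffer() / dist.get_world_size(process_group)
#         work = dist.all_reduce(tensor, group=process_group, async_op=True)
#         return work.get_future().then(lambda fut: fut.value()[0])
#
#     # ddp_model.register_comm_hook(process_group, allreduce_hook)
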
class Reducer:
    def __init__(
        self,
        params: List[Tensor],
        bucket_indices: List[List[int]],
        per_bucket_size_limits: List[int],
        process_group: ProcessGroup,
        expect_sparse_gradients: List[bool] = ...,
        bucket_bytes_cap: int = ...,  # kDefaultBucketBytesCap in reducer.hpp
        find_unused_parameters: bool = ...,
        gradient_as_bucket_view: bool = ...,
        param_to_name_mapping: Dict[int, str] = ...,
        first_bucket_types_cap: int = ...,  # kDefaultFirstBucketBytes in reducer.hpp
    ): ...
    def prepare_for_forward(self) -> None: ...
    def prepare_for_backward(self, output: List[Tensor]) -> None: ...
    def get_backward_stats(self) -> List[int]: ...
    def _install_post_backward_futures(self, futures: List[Future]) -> None: ...
    def _rebuild_buckets(self) -> bool: ...
    def _get_zeros_like_grad_buckets(self) -> List[GradBucket]: ...
    def _push_all_rebuilt_params(self) -> None: ...
    def _set_forward_pass_work_handle(
        self,
        work: Work,
        use_static_world_size: bool,
    ): ...
    def _get_local_used_map(self) -> Tensor: ...
    def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ...
    def _set_static_graph(self) -> None: ...
    def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
    def set_logger(self, logger: Logger) -> None: ...
    def _remove_autograd_hooks(self) -> None: ...
    def _check_reducer_finalized(self) -> None: ...
    def _set_sparse_metadata(self, global_unique_ids: Dict[str, Tensor]) -> None: ...
    def _reset_state(self) -> None: ...
    def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...

class DDPLoggingData:
    strs_map: Dict[str, str]
    ints_map: Dict[str, int]

class Logger:
    def __init__(self, reducer: Reducer): ...
    def set_construction_data_and_log(
        self,
        module_name: str,
        device_ids: List[int],
        output_device: int,
        broadcast_buffers: bool,
        has_sync_bn: bool,
        static_graph: bool,
    ): ...
    def set_runtime_stats_and_log(self) -> None: ...
    def set_error_and_log(self, error: str) -> None: ...
    def _get_ddp_logging_data(self) -> DDPLoggingData: ...
    def _set_comm_hook_name(self, comm_hook: str) -> None: ...
    def _set_uneven_input_join(self) -> None: ...
    def _set_static_graph(self) -> None: ...

class _WorkerServer:
    def __init__(self, socket_path: str) -> None: ...
    def shutdown(self) -> None: ...

def get_debug_level(): ...
def set_debug_level(): ...
def set_debug_level_from_env(): ...

class DebugLevel(Enum):
    OFF = ...
    INFO = ...
    DETAIL = ...

class ReduceOp:
    def __init__(self, op: RedOpType): ...

    SUM: RedOpType = ...
    AVG: RedOpType = ...
    PRODUCT: RedOpType = ...
    MIN: RedOpType = ...
    MAX: RedOpType = ...
    BAND: RedOpType = ...
    BOR: RedOpType = ...
    BXOR: RedOpType = ...
    PREMUL_SUM: RedOpType = ...
    UNUSED: RedOpType = ...

    class RedOpType(Enum): ...
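
# ReduceOp values select the reduction applied by collectives. Minimal usage
# sketch, assuming an already initialized default process group:
#
#     import torch
#     import torch.distributed as dist
#
#     t = torch.ones(4)
#     dist.all_reduce(t, op=dist.ReduceOp.SUM)     # in-place sum across ranks
#     dist.reduce(t, dst=0, op=dist.ReduceOp.MAX)  # rank 0 gets the element-wise max
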
class BroadcastOptions:
    rootRank: int
    rootTensor: int
    timeout: timedelta
    asyncOp: bool

class AllreduceOptions:
    reduceOp: ReduceOp
    timeout: timedelta

class AllreduceCoalescedOptions(AllreduceOptions): ...

class ReduceOptions:
    reduceOp: ReduceOp
    rootRank: int
    rootTensor: int
    timeout: timedelta

class AllgatherOptions:
    timeout: timedelta
    asyncOp: bool

class GatherOptions:
    rootRank: int
    timeout: timedelta

class ScatterOptions:
    rootRank: int
    timeout: timedelta
    asyncOp: bool

class ReduceScatterOptions:
    reduceOp: ReduceOp
    timeout: timedelta
    asyncOp: bool

class BarrierOptions:
    device_ids: List[int]
    device: torch.device
    timeout: timedelta

class AllToAllOptions:
    timeout: timedelta

class Store:
    def set(self, key: str, value: str): ...
    def get(self, key: str) -> bytes: ...
    def add(self, key: str, value: int) -> int: ...
    def compare_set(
        self,
        key: str,
        expected_value: str,
        desired_value: str,
    ) -> bytes: ...
    def delete_key(self, key: str) -> bool: ...
    def num_keys(self) -> int: ...
    def set_timeout(self, timeout: timedelta): ...
    @overload
    def wait(self, keys: List[str]): ...
    @overload
    def wait(self, keys: List[str], timeout: timedelta): ...

class FileStore(Store):
    def __init__(self, path: str, numWorkers: int = ...): ...

class HashStore(Store):
    def __init__(self): ...

class TCPStore(Store):
    def __init__(
        self,
        host_name: str,
        port: int,
        world_size: Optional[int] = ...,
        is_master: bool = ...,
        timeout: timedelta = ...,
        wait_for_workers: bool = ...,
        multi_tenant: bool = ...,
        master_listen_fd: Optional[int] = ...,
        use_libuv: Optional[bool] = ...,
    ): ...
    @property
    def host(self) -> str: ...
    @property
    def port(self) -> int: ...

class PrefixStore(Store):
    def __init__(self, prefix: str, store: Store): ...
    @property
    def underlying_store(self) -> Store: ...
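
# The Store classes provide the key/value rendezvous used to bootstrap process
# groups. Minimal usage sketch; the host, port, and `rank` variable are
# placeholder assumptions:
#
#     from datetime import timedelta
#     import torch.distributed as dist
#
#     # Rank 0 hosts the store; the other ranks connect to it.
#     store = dist.TCPStore("127.0.0.1", 29500, world_size=2,
#                           is_master=(rank == 0), timeout=timedelta(seconds=30))
#     store.set("key", "value")
#     store.get("key")                      # b'value'
#
#     # PrefixStore namespaces keys so several components can share one store.
#     prefixed = dist.PrefixStore("trainer/", store)
#     prefixed.set("key", "other")          # stored under "trainer/key"
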
class _ControlCollectives:
    def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ...
    def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ...
    def broadcast_recv(self, key: str, timeout: timedelta) -> str: ...
    def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ...
    def gather_recv(self, key: str, timeout: timedelta) -> str: ...
    def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ...
    def scatter_recv(self, key: str, timeout: timedelta) -> str: ...
    def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ...
    def all_sum(self, key: str, data: int, timeout: timedelta) -> int: ...

class _StoreCollectives(_ControlCollectives):
    def __init__(self, store: Store, rank: int, world_size: int) -> None: ...

class _DistributedBackendOptions:
    def __init__(self): ...
    @property
    def store(self) -> Store: ...
    @store.setter
    def store(self, store: Store) -> None: ...
    @property
    def group_rank(self) -> int: ...
    @group_rank.setter
    def group_rank(self, rank: int) -> None: ...
    @property
    def group_size(self) -> int: ...
    @group_size.setter
    def group_size(self, size: int) -> None: ...
    @property
    def timeout(self) -> timedelta: ...
    @timeout.setter
    def timeout(self, timeout: timedelta) -> None: ...
    @property
    def group_id(self) -> str: ...
    @group_id.setter
    def group_id(self, group_id: str) -> None: ...
    @property
    def global_ranks_in_group(self) -> List[int]: ...
    @global_ranks_in_group.setter
    def global_ranks_in_group(self, ranks: List[int]) -> None: ...

class Work:
    def is_completed(self) -> bool: ...
    def is_success(self) -> bool: ...
    def exception(self) -> Any: ...
    def wait(self, timeout: timedelta = ...) -> bool: ...
    def get_future(self) -> Future: ...
    def source_rank(self) -> int: ...
    def _source_rank(self) -> int: ...
    def result(self) -> List[Tensor]: ...
    def synchronize(self): ...
    def boxed(self) -> ScriptObject: ...
    @staticmethod
    def unbox(obj: ScriptObject) -> Work: ...
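
# Work is the handle returned by asynchronous collectives. Minimal usage sketch,
# assuming an already initialized default process group:
#
#     import torch
#     import torch.distributed as dist
#
#     t = torch.ones(4)
#     work = dist.all_reduce(t, async_op=True)  # returns a Work handle
#     work.wait()                               # block until the collective completes
#     # t now holds the reduced result; work.get_future() exposes the same
#     # completion as a torch.futures.Future for chaining callbacks.
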
class Backend:
    def __init__(
        self,
        rank: int,
        size: int,
    ): ...
    @property
    def supports_splitting(self) -> bool: ...
    def rank(self) -> int: ...
    def size(self) -> int: ...
    def eager_connect_single_device(self, device: Optional[torch.device]) -> None: ...
    def _set_sequence_number_for_group(self) -> None: ...

class ProcessGroup:
    class Options:
        def __init__(self, backend: str, timeout: timedelta = ...): ...
        @property
        def backend(self) -> str: ...
        @property
        def _timeout(self) -> timedelta: ...
        @_timeout.setter
        def _timeout(self, val: timedelta) -> None: ...

    class BackendType(Enum):
        UNDEFINED = ...
        GLOO = ...
        NCCL = ...
        UCC = ...
        MPI = ...
        CUSTOM = ...

    def __init__(self, store: Store, rank: int, size: int, options: Options): ...
    def rank(self) -> int: ...
    def size(self) -> int: ...
    @overload
    def broadcast(
        self,
        tensors: List[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def broadcast(
        self,
        tensor: Tensor,
        root: int,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensors: List[Tensor],
        opts: AllreduceOptions = ...,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensors: List[Tensor],
        op=...,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensor: Tensor,
        op=...,
    ) -> Work: ...
    def allreduce_coalesced(
        self,
        tensors: List[Tensor],
        opts=...,
    ) -> Work: ...
    def reduce_scatter_tensor_coalesced(
        self,
        outputTensors: List[Tensor],
        inputTensors: List[Tensor],
        opts: Optional[ReduceScatterOptions] = None,
    ) -> Work: ...
    @overload
    def reduce(
        self,
        tensors: List[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def reduce(
        self,
        tensor: Tensor,
        root: int,
        op=...,
    ) -> Work: ...
    @overload
    def allgather(
        self,
        output_tensors: List[List[Tensor]],
        input_tensors: List[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def allgather(
        self,
        output_tensors: List[Tensor],
        input_tensor: Tensor,
    ) -> Work: ...
    def _allgather_base(
        self,
        output: Tensor,
        input: Tensor,
        opts=...,
    ) -> Work: ...
    def allgather_coalesced(
        self,
        output_lists: List[List[Tensor]],
        input_list: List[Tensor],
        opts=...,
    ) -> Work: ...
    def allgather_into_tensor_coalesced(
        self,
        output_lists: List[Tensor],
        input_list: List[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def gather(
        self,
        output_tensors: List[List[Tensor]],
        input_tensors: List[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def gather(
        self,
        output_tensors: List[Tensor],
        input_tensor: Tensor,
        root: int,
    ) -> Work: ...
    @overload
    def scatter(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[List[Tensor]],
        opts=...,
    ) -> Work: ...
    @overload
    def scatter(
        self,
        output_tensor: Tensor,
        input_tensors: List[Tensor],
        root: int,
    ) -> Work: ...
    @overload
    def reduce_scatter(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[List[Tensor]],
        opts=...,
    ) -> Work: ...
    @overload
    def reduce_scatter(
        self,
        output_tensors: Tensor,
        input_tensor: List[Tensor],
    ) -> Work: ...
    def _reduce_scatter_base(
        self,
        outputTensor: Tensor,
        inputTensor: Tensor,
        opts: Optional[ReduceScatterOptions],
    ) -> Work: ...
    @overload
    def alltoall_base(
        self,
        output_tensor: Tensor,
        input_tensor: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
        opts=...,
    ) -> Work: ...
    @overload
    def alltoall_base(
        self,
        output: Tensor,
        input: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
    ) -> Work: ...
    @overload
    def alltoall(
        self,
        output_tensor: List[Tensor],
        input_tensor: List[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def alltoall(
        self,
        output: List[Tensor],
        input: List[Tensor],
    ) -> Work: ...
    def send(
        self,
        tensors: List[Tensor],
        dstRank: int,
        tag: int,
    ) -> Work: ...
    def recv(
        self,
        tensors: List[Tensor],
        srcRank: int,
        tag: int,
    ) -> Work: ...
    def recv_anysource(self, tensors: List[Tensor], tag: int) -> Work: ...
    def barrier(self, opts=...) -> Work: ...
    def boxed(self) -> ScriptObject: ...
    @staticmethod
    def unbox(obj: ScriptObject) -> ProcessGroup: ...
    def _start_coalescing(self, device: torch.device) -> None: ...
    def _end_coalescing(self, device: torch.device) -> Work: ...
    def _get_backend_name(self) -> str: ...
    def _backend_id(self, backend_type: BackendType) -> int: ...
    @property
    def _device_types(self) -> List[torch.device]: ...
    def _get_backend(self, device: torch.device) -> Backend: ...
    def _register_backend(
        self,
        device: torch.device,
        backend_type: BackendType,
        backend: Optional[Backend],
    ) -> None: ...
    def _set_group_name(self, name: str) -> None: ...
    def _set_group_desc(self, desc: str) -> None: ...
    def name(self) -> str: ...
    def _has_hooks(self) -> bool: ...
    def _wait_for_pending_works(self) -> None: ...
    def _set_sequence_number_for_group(self) -> None: ...
    @property
    def bound_device_id(self) -> Optional[torch.device]: ...
    @bound_device_id.setter
    def bound_device_id(self, device: Optional[torch.device]) -> None: ...
    @property
    def group_name(self) -> str: ...
    @property
    def group_desc(self) -> str: ...
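
# ProcessGroup instances are normally created through the public wrappers rather
# than constructed directly. Minimal usage sketch; the env:// rendezvous (via
# MASTER_ADDR/MASTER_PORT set by the launcher) is an assumption:
#
#     import torch
#     import torch.distributed as dist
#
#     dist.init_process_group(backend="gloo")  # builds the default ProcessGroup
#     pg = dist.group.WORLD                    # or dist.new_group([...]) for a subset
#     t = torch.ones(2) * dist.get_rank()
#     dist.all_reduce(t, group=pg)             # dispatches to the group's backend
#     dist.destroy_process_group()
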
class ProcessGroupRoundRobin(ProcessGroup): ...

def _round_robin_process_groups(
    process_groups: List[ProcessGroup],
) -> ProcessGroupRoundRobin: ...

class ProcessGroupGloo(Backend):
    class Device: ...
    class Options: ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ): ...
    @staticmethod
    def create_device(hostname="", interface="") -> Device: ...
    @staticmethod
    def create_default_device() -> Device: ...
    def _set_default_timeout(self, timeout) -> None: ...

class _ProcessGroupWrapper(Backend):
    def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo): ...
    wrapped_pg: Backend

class ProcessGroupNCCL(Backend):
    class Options:
        def __init__(self, timeout: Optional[timedelta] = None): ...
        @property
        def backend(self) -> str: ...
        @property
        def _timeout(self) -> timedelta: ...
        @_timeout.setter
        def _timeout(self, val: timedelta) -> None: ...
        @property
        def _is_high_priority_stream(self) -> bool: ...
        @_is_high_priority_stream.setter
        def _is_high_priority_stream(self, val: bool) -> None: ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ): ...
    def _group_start(self) -> None: ...
    def _group_end(self) -> None: ...
    def _set_default_timeout(self, timeout) -> None: ...
    def _shutdown(self) -> None: ...
    @property
    def uid(self) -> int: ...

class ProcessGroupUCC(Backend):
    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ): ...

class ProcessGroupMPI(Backend):
    def __init__(
        self,
        rank: int,
        size: int,
        pgComm: int,
    ): ...
    @staticmethod
    def create(ranks: List[int]) -> ProcessGroupMPI: ...

def _compute_bucket_assignment_by_size(
    tensors: List[Tensor],
    bucket_size_limits: List[int],
    expect_sparse_gradient: List[bool] = ...,
    tensor_indices: List[int] = ...,
) -> Tuple[List[List[int]], List[int]]: ...
def _broadcast_coalesced(
    process_group: ProcessGroup,
    tensors: List[Tensor],
    buffer_size: int,
    src: int,
): ...
def _test_python_store(store: Store): ...
def _verify_params_across_processes(
    process_group: ProcessGroup,
    params: List[Tensor],
    logger: Optional[Logger],
): ...
def _make_nccl_premul_sum(factor: Union[float, List[Tensor]]) -> ReduceOp: ...
def _register_process_group(
    group_name: str,
    process_group: ProcessGroup,
) -> None: ...
def _resolve_process_group(group_name: str) -> ProcessGroup: ...
def _unregister_all_process_groups() -> None: ...
def _unregister_process_group(group_name: str) -> None: ...

class ProcessGroupCudaP2P(Backend):
    class Options:
        nccl_options: Optional[ProcessGroupNCCL.Options]
        buffer_size: Optional[int]

        def __init__(self) -> None: ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        options: ProcessGroupCudaP2P.Options,
    ) -> None: ...
    def is_p2p_available(self) -> bool: ...
    def get_buffer_size(self) -> int: ...
    def stream(self) -> torch.cuda.Stream: ...
    def intra_node_barrier(self) -> Work: ...
    def get_p2p_buffer(
        self,
        rank: int,
        sizes: torch.Size,
        dtype: torch.dtype,
        storage_offset: Optional[int] = 0,
    ) -> torch.Tensor: ...
    def _shutdown(self) -> None: ...