__init__.py

# mypy: allow-untyped-defs
import os
import sys
from enum import Enum
import pdb
import io

import torch


def is_available() -> bool:
    """
    Return ``True`` if the distributed package is available.

    Otherwise,
    ``torch.distributed`` does not expose any other APIs. Currently,
    ``torch.distributed`` is available on Linux, MacOS and Windows. Set
    ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source.
    Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows,
    ``USE_DISTRIBUTED=0`` for MacOS.
    """
    return hasattr(torch._C, "_c10d_init")
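
# Example (editor's sketch, not part of the original file): a typical availability
# guard before setting up a process group. The "gloo" backend and env:// init
# method below are illustrative assumptions, not defaults mandated by this module.
#
#     >>> import torch.distributed as dist
#     >>> if dist.is_available():
#     ...     dist.init_process_group(backend="gloo", init_method="env://")
#     ...     print(dist.get_rank(), dist.get_world_size())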

if is_available() and not torch._C._c10d_init():
    raise RuntimeError("Failed to initialize torch.distributed")

# Custom Runtime Errors thrown from the distributed package
DistError = torch._C._DistError
DistBackendError = torch._C._DistBackendError
DistNetworkError = torch._C._DistNetworkError
DistStoreError = torch._C._DistStoreError

if is_available():
    from torch._C._distributed_c10d import (
        Store,
        FileStore,
        TCPStore,
        ProcessGroup as ProcessGroup,
        Backend as _Backend,
        PrefixStore,
        Reducer,
        Logger,
        BuiltinCommHookType,
        GradBucket,
        Work as _Work,
        _DEFAULT_FIRST_BUCKET_BYTES,
        _register_comm_hook,
        _register_builtin_comm_hook,
        _broadcast_coalesced,
        _compute_bucket_assignment_by_size,
        _verify_params_across_processes,
        _test_python_store,
        DebugLevel,
        get_debug_level,
        set_debug_level,
        set_debug_level_from_env,
        _make_nccl_premul_sum,
        _ControlCollectives,
        _StoreCollectives,
    )

    class _DistributedPdb(pdb.Pdb):
        """
        Supports using PDB from inside a multiprocessing child process.

        Usage:
        _DistributedPdb().set_trace()
        """
        def interaction(self, *args, **kwargs):
            _stdin = sys.stdin
            try:
                sys.stdin = open('/dev/stdin')
                pdb.Pdb.interaction(self, *args, **kwargs)
            finally:
                sys.stdin = _stdin
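
    # Example (editor's sketch, not part of the original file): dropping a single
    # rank into pdb with this helper. stdin is re-opened from /dev/stdin so the
    # prompt works inside a multiprocessing child; the rank guard shown here is
    # an illustrative assumption.
    #
    #     >>> if torch.distributed.get_rank() == 0:
    #     ...     _DistributedPdb().set_trace()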

    def breakpoint(rank: int = 0):
        """
        Set a breakpoint, but only on a single rank. All other ranks will wait for you to be
        done with the breakpoint before continuing.

        Args:
            rank (int): Which rank to break on. Default: ``0``
        """
        if get_rank() == rank:
            pdb = _DistributedPdb()
            pdb.message(
                "\n!!! ATTENTION !!!\n\n"
                f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
            )
            pdb.set_trace()

        # If Meta/Python keys are in the TLS, we want to make sure that we ignore them
        # and hit the (default) CPU/CUDA implementation of barrier.
        meta_in_tls = torch._C._meta_in_tls_dispatch_include()
        guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
        torch._C._set_meta_in_tls_dispatch_include(False)
        try:
            barrier()
        finally:
            torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
            del guard
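
    # Example (editor's sketch, not part of the original file): pausing rank 0 of a
    # running job under pdb while the other ranks block on the barrier inside
    # ``breakpoint()``. The init call is an illustrative assumption and relies on
    # env:// variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) being set.
    #
    #     >>> import torch.distributed as dist
    #     >>> dist.init_process_group("gloo")
    #     >>> dist.breakpoint(rank=0)  # only rank 0 drops into pdb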

    if sys.platform != "win32":
        from torch._C._distributed_c10d import (
            HashStore,
            _round_robin_process_groups,
        )

    from .distributed_c10d import *  # noqa: F403

    # Variables prefixed with underscore are not auto imported
    # See the comment in `distributed_c10d.py` above `_backend` on why we expose
    # this.
    from .distributed_c10d import (
        _all_gather_base,
        _reduce_scatter_base,
        _create_process_group_wrapper,
        _rank_not_in_group,
        _coalescing_manager,
        _CoalescingManager,
        _get_process_group_name,
        get_node_local_rank,
    )

    from .rendezvous import (
        rendezvous,
        _create_store_from_options,
        register_rendezvous_handler,
    )

    from .remote_device import _remote_device
    from .device_mesh import init_device_mesh, DeviceMesh

    set_debug_level_from_env()

else:
    # This stub is sufficient to get
    #   python test/test_public_bindings.py -k test_correct_module_names
    # working even when USE_DISTRIBUTED=0. Feel free to add more
    # stubs as necessary.
    # We cannot define stubs directly because they confuse pyre
    class _ProcessGroupStub:
        pass

    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]