control_plane.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import os
  2. from contextlib import contextmanager, ExitStack
  3. from typing import Generator
  4. from torch.distributed.elastic.multiprocessing.errors import record
  5. __all__ = [
  6. "worker_main",
  7. ]
  8. TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
  9. @contextmanager
  10. def _worker_server(socket_path: str) -> Generator[None, None, None]:
  11. from torch._C._distributed_c10d import _WorkerServer
  12. server = _WorkerServer(socket_path)
  13. try:
  14. yield
  15. finally:
  16. server.shutdown()
  17. @contextmanager
  18. @record
  19. def worker_main() -> Generator[None, None, None]:
  20. """
  21. This is a context manager that wraps your main entry function. This combines
  22. the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that
  23. exposes handlers via a unix socket specified by
  24. ``Torch_WORKER_SERVER_SOCKET``.
  25. Example
  26. ::
  27. @worker_main()
  28. def main():
  29. pass
  30. if __name__=="__main__":
  31. main()
  32. """
  33. with ExitStack() as stack:
  34. socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
  35. if socket_path is not None:
  36. stack.enter_context(_worker_server(socket_path))
  37. yield