| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- import os
- from contextlib import contextmanager, ExitStack
- from typing import Generator
- from torch.distributed.elastic.multiprocessing.errors import record
# Names exported by ``from <module> import *``.
__all__ = ["worker_main"]

# Environment variable consulted by ``worker_main`` for the unix socket path
# of the worker server.
TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
@contextmanager
def _worker_server(socket_path: str) -> Generator[None, None, None]:
    """
    Keep a ``_WorkerServer`` bound to ``socket_path`` alive for the duration
    of the ``with`` block.

    The server's ``shutdown()`` is invoked when the block exits, whether it
    exits normally or by raising. If constructing the server fails, nothing
    is shut down and the error propagates unchanged.
    """
    # Deferred import: the C extension symbol is only resolved when a server
    # is actually started.
    from torch._C._distributed_c10d import _WorkerServer

    with ExitStack() as cleanup:
        handle = _WorkerServer(socket_path)
        # Registered only after construction succeeds, mirroring a
        # try/finally placed right after the constructor call.
        cleanup.callback(handle.shutdown)
        yield
@contextmanager
@record
def worker_main() -> Generator[None, None, None]:
    """
    Context manager that wraps your main entry function.

    This combines the existing ``errors.record`` logic with a ``_WorkerServer``
    that exposes handlers via a unix socket. The socket path is read from the
    ``TORCH_WORKER_SERVER_SOCKET`` environment variable. When that variable is
    unset or empty, no server is started and the body simply runs.

    Because it is built with :func:`contextlib.contextmanager`, the object
    returned by ``worker_main()`` can be used either in a ``with`` statement
    or as a function decorator.

    Example::

        @worker_main()
        def main():
            pass

        if __name__ == "__main__":
            main()
    """
    # NOTE(review): ``record`` decorates this *generator function*, so it only
    # wraps the call that creates the generator. Exceptions raised later in
    # the managed body are thrown in at ``yield`` and may bypass ``record``'s
    # handling. Confirm against ``record``'s implementation.
    with ExitStack() as stack:
        socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
        # Treat an empty value like an unset variable instead of handing ""
        # to the server as a socket path.
        if socket_path:
            stack.enter_context(_worker_server(socket_path))
        yield
|