run.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. # Copyright (c) Facebook, Inc. and its affiliates.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the BSD-style license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. """
  9. Superset of ``torch.distributed.launch``.
  10. ``torchrun`` provides a superset of the functionality of ``torch.distributed.launch``
  11. with the following additional functionalities:
  12. 1. Worker failures are handled gracefully by restarting all workers.
  13. 2. Worker ``RANK`` and ``WORLD_SIZE`` are assigned automatically.
  14. 3. Number of nodes is allowed to change between minimum and maximum sizes (elasticity).
  15. .. note:: ``torchrun`` is a python
  16. `console script <https://packaging.python.org/en/latest/specifications/entry-points/#use-for-scripts>`_
  17. to the main module
  18. `torch.distributed.run <https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py>`_
  19. declared in the ``entry_points`` configuration in
  20. `setup.py <https://github.com/pytorch/pytorch/blob/master/setup.py>`_.
  21. It is equivalent to invoking ``python -m torch.distributed.run``.
  22. Transitioning from torch.distributed.launch to torchrun
  23. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  24. ``torchrun`` supports the same arguments as ``torch.distributed.launch`` **except**
  25. for ``--use-env`` which is now deprecated. To migrate from ``torch.distributed.launch``
  26. to ``torchrun`` follow these steps:
  27. 1. If your training script already reads ``local_rank`` from the ``LOCAL_RANK`` environment variable,
  28. then simply omit the ``--use-env`` flag, e.g.:
  29. +--------------------------------------------------------------------+--------------------------------------------+
  30. | ``torch.distributed.launch`` | ``torchrun`` |
  31. +====================================================================+============================================+
  32. | | |
  33. | .. code-block:: shell-session | .. code-block:: shell-session |
  34. | | |
  35. | $ python -m torch.distributed.launch --use-env train_script.py | $ torchrun train_script.py |
  36. | | |
  37. +--------------------------------------------------------------------+--------------------------------------------+
  38. 2. If your training script reads the local rank from a ``--local-rank`` command-line argument,
  39. change your training script to read from the ``LOCAL_RANK`` environment variable as
  40. demonstrated by the following code snippet:
  41. +-------------------------------------------------------+----------------------------------------------------+
  42. | ``torch.distributed.launch`` | ``torchrun`` |
  43. +=======================================================+====================================================+
  44. | | |
  45. | .. code-block:: python | .. code-block:: python |
  46. | | |
  47. | | |
  48. | import argparse | import os |
  49. | parser = argparse.ArgumentParser() | local_rank = int(os.environ["LOCAL_RANK"]) |
  50. | parser.add_argument("--local-rank", type=int) | |
  51. | args = parser.parse_args() | |
  52. | | |
  53. | local_rank = args.local_rank | |
  54. | | |
  55. +-------------------------------------------------------+----------------------------------------------------+
  56. .. versionchanged:: 2.0.0
  57. The launcher will pass the ``--local-rank=<rank>`` argument to your script.
  58. From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the
  59. previously used underscored ``--local_rank``.
  60. For backward compatibility, it may be necessary for users to handle both
  61. cases in their argument parsing code. This means including both ``"--local-rank"``
  62. and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is
  63. provided, the launcher will trigger an error: "error: unrecognized arguments:
  64. --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+,
  65. including ``"--local-rank"`` should be sufficient.
  66. ::
  67. >>> # xdoctest: +SKIP
  68. >>> import argparse
  69. >>> parser = argparse.ArgumentParser()
  70. >>> parser.add_argument("--local-rank", "--local_rank", type=int)
  71. >>> args = parser.parse_args()
  72. The aforementioned changes suffice to migrate from ``torch.distributed.launch`` to ``torchrun``.
  73. To take advantage of new features such as elasticity, fault-tolerance, and error reporting of ``torchrun``
  74. please refer to:
  75. * :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant.
  76. * the rest of this page for more information on the features of ``torchrun``.
  77. Usage
  78. --------
  79. Single-node multi-worker
  80. ++++++++++++++++++++++++++++++
  81. ::
  82. torchrun
  83. --standalone
  84. --nnodes=1
  85. --nproc-per-node=$NUM_TRAINERS
  86. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  87. Stacked single-node multi-worker
  88. +++++++++++++++++++++++++++++++++++
  89. To run multiple instances (separate jobs) of single-node, multi-worker on the
  90. same host, we need to make sure that each instance (job) is
  91. set up on different ports to avoid port conflicts (or worse, two jobs being merged
  92. as a single job). To do this you have to run with ``--rdzv-backend=c10d``
  93. and specify a different port by setting ``--rdzv-endpoint=localhost:$PORT_k``.
  94. For ``--nnodes=1``, it's often convenient to let ``torchrun`` pick a free random
  95. port automatically instead of manually assigning different ports for each run.
  96. ::
  97. torchrun
  98. --rdzv-backend=c10d
  99. --rdzv-endpoint=localhost:0
  100. --nnodes=1
  101. --nproc-per-node=$NUM_TRAINERS
  102. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  103. Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures)
  104. ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  105. ::
  106. torchrun
  107. --nnodes=$NUM_NODES
  108. --nproc-per-node=$NUM_TRAINERS
  109. --max-restarts=3
  110. --rdzv-id=$JOB_ID
  111. --rdzv-backend=c10d
  112. --rdzv-endpoint=$HOST_NODE_ADDR
  113. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  114. ``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and
  115. the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
  116. node in your training cluster, but ideally you should pick a node that has a high bandwidth.
  117. .. note::
  118. If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400.
  119. Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures)
  120. +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  121. ::
  122. torchrun
  123. --nnodes=1:4
  124. --nproc-per-node=$NUM_TRAINERS
  125. --max-restarts=3
  126. --rdzv-id=$JOB_ID
  127. --rdzv-backend=c10d
  128. --rdzv-endpoint=$HOST_NODE_ADDR
  129. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  130. ``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and
  131. the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
  132. node in your training cluster, but ideally you should pick a node that has a high bandwidth.
  133. .. note::
  134. If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400.
  135. Note on rendezvous backend
  136. ------------------------------
  137. For multi-node training you need to specify:
  138. 1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job)
  139. 2. ``--rdzv-backend``: An implementation of
  140. :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler`
  141. 3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in form
  142. ``host:port``.
  143. Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are
  144. supported out of the box. To use ``etcd-v2`` or ``etcd``, set up an etcd server with the ``v2`` API
  145. enabled (e.g. ``--enable-v2``).
  146. .. warning::
  147. ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd
  148. server. Our tests use etcd v3.4.3.
  149. .. warning::
  150. For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd`` which is functionally
  151. equivalent, but uses a revised implementation. ``etcd`` is in maintenance mode and will be
  152. removed in a future version.
  153. Definitions
  154. --------------
  155. 1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with.
  156. 2. ``Worker`` - A worker in the context of distributed training.
  157. 3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers).
  158. 4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node.
  159. 5. ``RANK`` - The rank of the worker within a worker group.
  160. 6. ``WORLD_SIZE`` - The total number of workers in a worker group.
  161. 7. ``LOCAL_RANK`` - The rank of the worker within a local worker group.
  162. 8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group.
  163. 9. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is
  164. used by each node to join as a member of a particular worker group.
  165. 10. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly
  166. consistent key-value store.
  167. 11. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in form ``<host>:<port>``.
  168. A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers which comprise a ``LocalWorkerGroup``. The union of
  169. all ``LocalWorkerGroups`` in the nodes in the job comprise the ``WorkerGroup``.
  170. Environment Variables
  171. ----------------------
  172. The following environment variables are made available to you in your script:
  173. 1. ``LOCAL_RANK`` - The local rank.
  174. 2. ``RANK`` - The global rank.
  175. 3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When
  176. running a single worker group per node, this is the rank of the node.
  177. 4. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role
  178. of the worker is specified in the ``WorkerSpec``.
  179. 5. ``LOCAL_WORLD_SIZE`` - The local world size (i.e. number of workers running locally); equal to
  180. ``--nproc-per-node`` specified on ``torchrun``.
  181. 6. ``WORLD_SIZE`` - The world size (total number of workers in the job).
  182. 7. ``ROLE_WORLD_SIZE`` - The total number of workers that were launched with the same role specified
  183. in ``WorkerSpec``.
  184. 8. ``MASTER_ADDR`` - The FQDN of the host that is running worker with rank 0; used to initialize
  185. the Torch Distributed backend.
  186. 9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store.
  187. 10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far.
  188. 11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts.
  189. 12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (e.g. unique job id).
  190. 13. ``PYTHON_EXEC`` - System executable override. If provided, the python user script will
  191. use the value of ``PYTHON_EXEC`` as the executable. ``sys.executable`` is used by default.
  192. Deployment
  193. ------------
  194. 1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be
  195. passed as ``--rdzv-endpoint`` to the launcher script)
  196. 2. Single-node multi-worker: Start the launcher on the host to start the agent process which
  197. creates and monitors a local worker group.
  198. 3. Multi-node multi-worker: Start the launcher with the same arguments on all the nodes
  199. participating in training.
  200. When using a job/cluster manager the entry point command to the multi-node job should be this
  201. launcher.
  202. Failure Modes
  203. ---------------
  204. 1. Worker failure: For a training job with ``n`` workers, if ``k<=n`` workers fail, all workers
  205. are stopped and restarted up to ``max_restarts``.
  206. 2. Agent failure: An agent failure results in a local worker group failure. It is up to the job
  207. manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors
  208. are supported by the agent.
  209. 3. Node failure: Same as agent failure.
  210. Membership Changes
  211. --------------------
  212. 1. Node departure (scale-down): The agent is notified of the departure, all existing workers are
  213. stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
  214. ``WORLD_SIZE``.
  215. 2. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped,
  216. a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
  217. ``WORLD_SIZE``.
  218. Important Notices
  219. --------------------
  220. 1. This utility and multi-process distributed (single-node or
  221. multi-node) GPU training currently only achieves the best performance using
  222. the NCCL distributed backend. Thus NCCL backend is the recommended backend to
  223. use for GPU training.
  224. 2. The environment variables necessary to initialize a Torch process group are provided to you by
  225. this module, no need for you to pass ``RANK`` manually. To initialize a process group in your
  226. training script, simply run:
  227. ::
  228. >>> # xdoctest: +SKIP("stub")
  229. >>> import torch.distributed as dist
  230. >>> dist.init_process_group(backend="gloo|nccl")
  231. 3. In your training program, you can either use regular distributed functions
  232. or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
  233. training program uses GPUs for training and you would like to use
  234. :func:`torch.nn.parallel.DistributedDataParallel` module,
  235. here is how to configure it.
  236. ::
  237. local_rank = int(os.environ["LOCAL_RANK"])
  238. model = torch.nn.parallel.DistributedDataParallel(model,
  239. device_ids=[local_rank],
  240. output_device=local_rank)
  241. Please ensure that ``device_ids`` argument is set to be the only GPU device id
  242. that your code will be operating on. This is generally the local rank of the
  243. process. In other words, the ``device_ids`` needs to be ``[int(os.environ["LOCAL_RANK"])]``,
  244. and ``output_device`` needs to be ``int(os.environ["LOCAL_RANK"])`` in order to use this
  245. utility.
  246. 4. On failures or membership changes ALL surviving workers are killed immediately. Make sure to
  247. checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance
  248. for lost work.
  249. 5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all
  250. nodes run the same number of local workers (per role).
  251. 6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a
  252. different range of ranks than before. NEVER hard code any assumptions about the stable-ness of
  253. ranks or some correlation between ``RANK`` and ``LOCAL_RANK``.
  254. 7. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about
  255. ``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join.
  256. 8. It is recommended for your script to have the following structure:
  257. ::
  258. def main():
  259. load_checkpoint(checkpoint_path)
  260. initialize()
  261. train()
  262. def train():
  263. for batch in iter(dataset):
  264. train_step(batch)
  265. if should_checkpoint:
  266. save_checkpoint(checkpoint_path)
  267. 9. (Recommended) On worker errors, this tool will summarize the details of the error
  268. (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
  269. is heuristically reported as the "Root Cause" error. To get tracebacks as part of this
  270. error summary print out, you must decorate your main entrypoint function in your
  271. training script as shown in the example below. If not decorated, then the summary
  272. will not include the traceback of the exception and will only contain the exitcode.
  273. For details on torchelastic error handling see: https://pytorch.org/docs/stable/elastic/errors.html
  274. ::
  275. from torch.distributed.elastic.multiprocessing.errors import record
  276. @record
  277. def main():
  278. # do train
  279. pass
  280. if __name__ == "__main__":
  281. main()
  282. """
  283. import logging
  284. import os
  285. import sys
  286. import uuid
  287. import importlib.metadata as metadata
  288. from argparse import REMAINDER, ArgumentParser
  289. from typing import Callable, List, Tuple, Type, Union, Optional, Set
  290. import torch
  291. from torch.distributed.argparse_util import check_env, env
  292. from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
  293. from torch.distributed.elastic.multiprocessing.errors import record
  294. from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
  295. from torch.distributed.elastic.utils import macros
  296. from torch.distributed.elastic.utils.logging import get_logger
  297. from torch.distributed.launcher.api import LaunchConfig, elastic_launch
  298. from torch.utils.backend_registration import _get_custom_mod_func
  299. logger = get_logger(__name__)
  300. def get_args_parser() -> ArgumentParser:
  301. """Parse the command line options."""
  302. parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher")
  303. #
  304. # Worker/node size related arguments.
  305. #
  306. parser.add_argument(
  307. "--nnodes",
  308. action=env,
  309. type=str,
  310. default="1:1",
  311. help="Number of nodes, or the range of nodes in form <minimum_nodes>:<maximum_nodes>.",
  312. )
  313. parser.add_argument(
  314. "--nproc-per-node",
  315. "--nproc_per_node",
  316. action=env,
  317. type=str,
  318. default="1",
  319. help="Number of workers per node; supported values: [auto, cpu, gpu, int].",
  320. )
  321. #
  322. # Rendezvous related arguments
  323. #
  324. parser.add_argument(
  325. "--rdzv-backend",
  326. "--rdzv_backend",
  327. action=env,
  328. type=str,
  329. default="static",
  330. help="Rendezvous backend.",
  331. )
  332. parser.add_argument(
  333. "--rdzv-endpoint",
  334. "--rdzv_endpoint",
  335. action=env,
  336. type=str,
  337. default="",
  338. help="Rendezvous backend endpoint; usually in form <host>:<port>.",
  339. )
  340. parser.add_argument(
  341. "--rdzv-id",
  342. "--rdzv_id",
  343. action=env,
  344. type=str,
  345. default="none",
  346. help="User-defined group id.",
  347. )
  348. parser.add_argument(
  349. "--rdzv-conf",
  350. "--rdzv_conf",
  351. action=env,
  352. type=str,
  353. default="",
  354. help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
  355. )
  356. parser.add_argument(
  357. "--standalone",
  358. action=check_env,
  359. help="Start a local standalone rendezvous backend that is represented by a C10d TCP store "
  360. "on a free port. Useful when launching single-node, multi-worker job. If specified "
  361. "--rdzv-backend, --rdzv-endpoint, --rdzv-id are auto-assigned and any explicitly set values "
  362. "are ignored.",
  363. )
  364. #
  365. # User-code launch related arguments.
  366. #
  367. parser.add_argument(
  368. "--max-restarts",
  369. "--max_restarts",
  370. action=env,
  371. type=int,
  372. default=0,
  373. help="Maximum number of worker group restarts before failing.",
  374. )
  375. parser.add_argument(
  376. "--monitor-interval",
  377. "--monitor_interval",
  378. action=env,
  379. type=float,
  380. default=0.1,
  381. help="Interval, in seconds, to monitor the state of workers.",
  382. )
  383. parser.add_argument(
  384. "--start-method",
  385. "--start_method",
  386. action=env,
  387. type=str,
  388. default="spawn",
  389. choices=["spawn", "fork", "forkserver"],
  390. help="Multiprocessing start method to use when creating workers.",
  391. )
  392. parser.add_argument(
  393. "--role",
  394. action=env,
  395. type=str,
  396. default="default",
  397. help="User-defined role for the workers.",
  398. )
  399. parser.add_argument(
  400. "-m",
  401. "--module",
  402. action=check_env,
  403. help="Change each process to interpret the launch script as a Python module, executing "
  404. "with the same behavior as 'python -m'.",
  405. )
  406. parser.add_argument(
  407. "--no-python",
  408. "--no_python",
  409. action=check_env,
  410. help="Skip prepending the training script with 'python' - just execute it directly. Useful "
  411. "when the script is not a Python script.",
  412. )
  413. parser.add_argument(
  414. "--run-path",
  415. "--run_path",
  416. action=check_env,
  417. help="Run the training script with runpy.run_path in the same interpreter."
  418. " Script must be provided as an abs path (e.g. /abs/path/script.py)."
  419. " Takes precedence over --no-python.",
  420. )
  421. parser.add_argument(
  422. "--log-dir",
  423. "--log_dir",
  424. action=env,
  425. type=str,
  426. default=None,
  427. help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same "
  428. "directory is re-used for multiple runs (a unique job-level sub-directory is created with "
  429. "rdzv_id as the prefix).",
  430. )
  431. parser.add_argument(
  432. "-r",
  433. "--redirects",
  434. action=env,
  435. type=str,
  436. default="0",
  437. help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects "
  438. "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and "
  439. "stderr for local rank 1).",
  440. )
  441. parser.add_argument(
  442. "-t",
  443. "--tee",
  444. action=env,
  445. type=str,
  446. default="0",
  447. help="Tee std streams into a log file and also to console (see --redirects for format).",
  448. )
  449. parser.add_argument(
  450. "--local-ranks-filter",
  451. "--local_ranks_filter",
  452. action=env,
  453. type=str,
  454. default="",
  455. help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will "
  456. "only show logs from rank 0, 1 and 2). This will only apply to stdout and stderr, not to"
  457. "log files saved via --redirect or --tee",
  458. )
  459. #
  460. # Backwards compatible parameters with caffe2.distributed.launch.
  461. #
  462. parser.add_argument(
  463. "--node-rank",
  464. "--node_rank",
  465. type=int,
  466. action=env,
  467. default=0,
  468. help="Rank of the node for multi-node distributed training.",
  469. )
  470. parser.add_argument(
  471. "--master-addr",
  472. "--master_addr",
  473. default="127.0.0.1",
  474. type=str,
  475. action=env,
  476. help="Address of the master node (rank 0) that only used for static rendezvous. It should "
  477. "be either the IP address or the hostname of rank 0. For single node multi-proc training "
  478. "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern "
  479. "`[0:0:0:0:0:0:0:1]`.",
  480. )
  481. parser.add_argument(
  482. "--master-port",
  483. "--master_port",
  484. default=29500,
  485. type=int,
  486. action=env,
  487. help="Port on the master node (rank 0) to be used for communication during distributed "
  488. "training. It is only used for static rendezvous.",
  489. )
  490. parser.add_argument(
  491. "--local-addr",
  492. "--local_addr",
  493. default=None,
  494. type=str,
  495. action=env,
  496. help="Address of the local node. If specified, will use the given address for connection. "
  497. "Else, will look up the local node address instead. Else, it will be default to local "
  498. "machine's FQDN.",
  499. )
  500. parser.add_argument(
  501. "--logs-specs",
  502. "--logs_specs",
  503. default=None,
  504. type=str,
  505. help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. "
  506. "Can be used to override custom logging behavior.",
  507. )
  508. #
  509. # Positional arguments.
  510. #
  511. parser.add_argument(
  512. "training_script",
  513. type=str,
  514. help="Full path to the (single GPU) training program/script to be launched in parallel, "
  515. "followed by all the arguments for the training script.",
  516. )
  517. # Rest from the training program.
  518. parser.add_argument("training_script_args", nargs=REMAINDER)
  519. return parser
  520. def parse_args(args):
  521. parser = get_args_parser()
  522. return parser.parse_args(args)
  523. def parse_min_max_nnodes(nnodes: str):
  524. arr = nnodes.split(":")
  525. if len(arr) == 1:
  526. min_nodes = max_nodes = int(arr[0])
  527. elif len(arr) == 2:
  528. min_nodes = int(arr[0])
  529. max_nodes = int(arr[1])
  530. else:
  531. raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format') # noqa: E231
  532. return min_nodes, max_nodes
  533. def determine_local_world_size(nproc_per_node: str):
  534. try:
  535. logging.info("Using nproc_per_node=%s.", nproc_per_node)
  536. return int(nproc_per_node)
  537. except ValueError as e:
  538. if nproc_per_node == "cpu":
  539. num_proc = os.cpu_count()
  540. device_type = "cpu"
  541. elif nproc_per_node == "gpu":
  542. if not torch.cuda.is_available():
  543. raise ValueError("Cuda is not available.") from e
  544. device_type = "gpu"
  545. num_proc = torch.cuda.device_count()
  546. elif nproc_per_node == torch._C._get_privateuse1_backend_name():
  547. if not _get_custom_mod_func("is_available")():
  548. raise ValueError(f"{nproc_per_node} is not available.") from e
  549. device_type = nproc_per_node
  550. num_proc = _get_custom_mod_func("device_count")()
  551. elif nproc_per_node == "auto":
  552. if torch.cuda.is_available():
  553. num_proc = torch.cuda.device_count()
  554. device_type = "gpu"
  555. elif hasattr(torch, torch._C._get_privateuse1_backend_name()) and \
  556. _get_custom_mod_func("is_available")():
  557. num_proc = _get_custom_mod_func("device_count")()
  558. device_type = torch._C._get_privateuse1_backend_name()
  559. else:
  560. num_proc = os.cpu_count()
  561. device_type = "cpu"
  562. else:
  563. raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}") from e
  564. logger.info(
  565. "Using nproc_per_node=%s,"
  566. " setting to %s since the instance "
  567. "has %s %s",
  568. nproc_per_node, num_proc, os.cpu_count(), device_type
  569. )
  570. return num_proc
  571. def get_rdzv_endpoint(args):
  572. if args.rdzv_backend == "static" and not args.rdzv_endpoint:
  573. return f"{args.master_addr}:{args.master_port}" # noqa: E231
  574. return args.rdzv_endpoint
  575. def get_use_env(args) -> bool:
  576. """
  577. Retrieve ``use_env`` from the args.
  578. ``use_env`` is a legacy argument, if ``use_env`` is False, the
  579. ``--node-rank`` argument will be transferred to all worker processes.
  580. ``use_env`` is only used by the ``torch.distributed.launch`` and will
  581. be deprecated in future releases.
  582. """
  583. if not hasattr(args, "use_env"):
  584. return True
  585. return args.use_env
  586. def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
  587. """
  588. Attemps to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param.
  589. Provides plugin mechanism to provide custom implementation of LogsSpecs.
  590. Returns `DefaultLogsSpecs` when logs_spec_name is None.
  591. Raises ValueError when entrypoint for `logs_spec_name` can't be found in entrypoints.
  592. """
  593. logs_specs_cls = None
  594. if logs_specs_name is not None:
  595. eps = metadata.entry_points()
  596. if hasattr(eps, "select"): # >= 3.10
  597. group = eps.select(group="torchrun.logs_specs")
  598. if group.select(name=logs_specs_name):
  599. logs_specs_cls = group[logs_specs_name].load()
  600. elif specs := eps.get("torchrun.logs_specs"): # < 3.10
  601. if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
  602. logs_specs_cls = entrypoint_list[0].load()
  603. if logs_specs_cls is None:
  604. raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key")
  605. logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls))
  606. else:
  607. logs_specs_cls = DefaultLogsSpecs
  608. return logs_specs_cls
  609. def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
  610. # If ``args`` not passed, defaults to ``sys.argv[:1]``
  611. min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
  612. assert 0 < min_nodes <= max_nodes
  613. assert args.max_restarts >= 0
  614. if hasattr(args, "master_addr") and args.rdzv_backend != "static" and not args.rdzv_endpoint:
  615. logger.warning(
  616. "master_addr is only used for static rdzv_backend and when rdzv_endpoint "
  617. "is not specified."
  618. )
  619. nproc_per_node = determine_local_world_size(args.nproc_per_node)
  620. if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
  621. omp_num_threads = 1
  622. logger.warning(
  623. "\n*****************************************\n"
  624. "Setting OMP_NUM_THREADS environment variable for each process to be "
  625. "%s in default, to avoid your system being overloaded, "
  626. "please further tune the variable for optimal performance in "
  627. "your application as needed. \n"
  628. "*****************************************",
  629. omp_num_threads
  630. )
  631. # This env variable will be passed down to the subprocesses
  632. os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
  633. log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE")
  634. rdzv_configs = _parse_rendezvous_config(args.rdzv_conf)
  635. if args.rdzv_backend == "static":
  636. rdzv_configs["rank"] = args.node_rank
  637. rdzv_endpoint = get_rdzv_endpoint(args)
  638. ranks: Optional[Set[int]] = None
  639. if args.local_ranks_filter:
  640. try:
  641. ranks = set(map(int, args.local_ranks_filter.split(",")))
  642. assert ranks
  643. except Exception as e:
  644. raise ValueError(
  645. "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
  646. ) from e
  647. logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
  648. logs_specs = logs_specs_cls(
  649. log_dir=args.log_dir,
  650. redirects=Std.from_str(args.redirects),
  651. tee=Std.from_str(args.tee),
  652. local_ranks_filter=ranks,
  653. )
  654. config = LaunchConfig(
  655. min_nodes=min_nodes,
  656. max_nodes=max_nodes,
  657. nproc_per_node=nproc_per_node,
  658. run_id=args.rdzv_id,
  659. role=args.role,
  660. rdzv_endpoint=rdzv_endpoint,
  661. rdzv_backend=args.rdzv_backend,
  662. rdzv_configs=rdzv_configs,
  663. max_restarts=args.max_restarts,
  664. monitor_interval=args.monitor_interval,
  665. start_method=args.start_method,
  666. log_line_prefix_template=log_line_prefix_template,
  667. local_addr=args.local_addr,
  668. logs_specs=logs_specs,
  669. )
  670. with_python = not args.no_python
  671. cmd: Union[Callable, str]
  672. cmd_args = []
  673. use_env = get_use_env(args)
  674. if args.run_path:
  675. cmd = run_script_path
  676. cmd_args.append(args.training_script)
  677. else:
  678. if with_python:
  679. cmd = os.getenv("PYTHON_EXEC", sys.executable)
  680. cmd_args.append("-u")
  681. if args.module:
  682. cmd_args.append("-m")
  683. cmd_args.append(args.training_script)
  684. else:
  685. if args.module:
  686. raise ValueError(
  687. "Don't use both the '--no-python' flag"
  688. " and the '--module' flag at the same time."
  689. )
  690. cmd = args.training_script
  691. if not use_env:
  692. cmd_args.append(f"--local-rank={macros.local_rank}")
  693. cmd_args.extend(args.training_script_args)
  694. return config, cmd, cmd_args
  695. def run_script_path(training_script: str, *training_script_args: str):
  696. """
  697. Run the provided `training_script` from within this interpreter.
  698. Usage: `script_as_function("/abs/path/to/script.py", "--arg1", "val1")`
  699. """
  700. import runpy
  701. import sys
  702. sys.argv = [training_script] + [*training_script_args]
  703. runpy.run_path(sys.argv[0], run_name="__main__")
  704. def run(args):
  705. if args.standalone:
  706. args.rdzv_backend = "c10d"
  707. args.rdzv_endpoint = "localhost:0"
  708. args.rdzv_id = str(uuid.uuid4())
  709. logger.info(
  710. "\n**************************************\n"
  711. "Rendezvous info:\n"
  712. "--rdzv-backend=%s "
  713. "--rdzv-endpoint=%s "
  714. "--rdzv-id=%s\n"
  715. "**************************************\n",
  716. args.rdzv_backend, args.rdzv_endpoint, args.rdzv_id
  717. )
  718. config, cmd, cmd_args = config_from_args(args)
  719. elastic_launch(
  720. config=config,
  721. entrypoint=cmd,
  722. )(*cmd_args)
  723. @record
  724. def main(args=None):
  725. args = parse_args(args)
  726. run(args)
if __name__ == "__main__":
    # Entry point when this module is executed directly as a script.
    main()