# mypy: allow-untyped-defs
r"""
Module ``torch.distributed.launch``.

``torch.distributed.launch`` is a module that spawns multiple distributed
training processes on each of the training nodes.

.. warning::

    This module is going to be deprecated in favor of :ref:`torchrun <launcher-api>`.

The utility can be used for single-node distributed training, in which one or
more processes per node will be spawned. The utility can be used for either
CPU training or GPU training. If the utility is used for GPU training,
each distributed process will operate on a single GPU, which can substantially
improve single-node training performance. It can also be used in
multi-node distributed training, by spawning multiple processes on each node,
to improve multi-node distributed training performance as well.
This is especially beneficial for systems with multiple InfiniBand
interfaces that have direct GPU support, since all of them can be utilized for
aggregated communication bandwidth.

In both the single-node and multi-node distributed training cases, this utility
will launch the given number of processes per node (``--nproc-per-node``). If
used for GPU training, this number must be less than or equal to the number of
GPUs on the current system, and each process will operate on a single GPU, from
*GPU 0 to GPU (nproc_per_node - 1)*.
**How to use this module:**

1. Single-Node multi-process distributed training

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
               arguments of your training script)

2. Multi-Node multi-process distributed training: (e.g. two nodes)

Node 1: *(IP: 192.168.1.1, and has a free port: 1234)*

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               --nnodes=2 --node-rank=0 --master-addr="192.168.1.1"
               --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
               and all other arguments of your training script)

Node 2:

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               --nnodes=2 --node-rank=1 --master-addr="192.168.1.1"
               --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
               and all other arguments of your training script)

3. To look up what optional arguments this module offers:

::

    python -m torch.distributed.launch --help
**Important Notices:**

1. This utility and multi-process distributed (single-node or
multi-node) GPU training currently only achieves the best performance using
the NCCL distributed backend. Thus the NCCL backend is the recommended backend
to use for GPU training.
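For example, a GPU training script would typically select the NCCL backend when
initializing the process group; a minimal sketch (the general initialization
call is shown again in notice 3 below):

::

    >>> # xdoctest: +SKIP
    >>> torch.distributed.init_process_group(backend='nccl',
    >>>                                      init_method='env://')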
2. In your training program, you must parse the command-line argument:
``--local-rank=LOCAL_PROCESS_RANK``, which will be provided by this module.
If your training program uses GPUs, you should ensure that your code only
runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:

Parsing the local_rank argument

::

    >>> # xdoctest: +SKIP
    >>> import argparse
    >>> parser = argparse.ArgumentParser()
    >>> parser.add_argument("--local-rank", "--local_rank", type=int)
    >>> args = parser.parse_args()

Setting your device to the local rank using either

::

    >>> torch.cuda.set_device(args.local_rank)  # before your code runs

or

::

    >>> with torch.cuda.device(args.local_rank):
    >>>     # your code to run
    >>>     ...

.. versionchanged:: 2.0.0

    The launcher passes the ``--local-rank=<rank>`` argument to your script.
    From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the
    previously used underscored ``--local_rank``.

    For backward compatibility, it may be necessary for users to handle both
    cases in their argument parsing code. This means including both ``"--local-rank"``
    and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is
    provided, the launcher will trigger an error: "error: unrecognized arguments:
    --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+,
    including ``"--local-rank"`` should be sufficient.
3. In your training program, you are supposed to call the following function
at the beginning to start the distributed backend. It is strongly recommended
to use ``init_method='env://'``. Other init methods (e.g. ``tcp://``) may work,
but ``env://`` is the one that is officially supported by this module.

::

    >>> torch.distributed.init_process_group(backend='YOUR BACKEND',
    >>>                                      init_method='env://')
4. In your training program, you can either use regular distributed functions
or use the :class:`torch.nn.parallel.DistributedDataParallel` wrapper. If your
training program uses GPUs for training and you would like to use
:class:`torch.nn.parallel.DistributedDataParallel`, here is how to configure it.

::

    >>> model = torch.nn.parallel.DistributedDataParallel(model,
    >>>                                                   device_ids=[args.local_rank],
    >>>                                                   output_device=args.local_rank)

Please ensure that the ``device_ids`` argument is set to the only GPU device id
that your code will be operating on. This is generally the local rank of the
process. In other words, ``device_ids`` needs to be ``[args.local_rank]``,
and ``output_device`` needs to be ``args.local_rank`` in order to use this
utility.
5. Another way to pass ``local_rank`` to the subprocesses is via the environment
variable ``LOCAL_RANK``. This behavior is enabled when you launch the script
with ``--use-env=True``. You must adjust the subprocess example above to replace
``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher
will not pass ``--local-rank`` when you specify this flag.
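For example, with ``--use-env=True`` the device setup from notice 2 above might
read the rank from the environment instead; a minimal sketch:

::

    >>> # xdoctest: +SKIP
    >>> import os
    >>> local_rank = int(os.environ["LOCAL_RANK"])
    >>> torch.cuda.set_device(local_rank)  # before your code runs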
.. warning::

    ``local_rank`` is NOT globally unique: it is only unique per process
    on a machine. Thus, don't use it to decide if you should, e.g.,
    write to a networked filesystem. See
    https://github.com/pytorch/pytorch/issues/12042 for an example of
    how things can go wrong if you don't do this correctly.
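    If a single writer is needed, one common pattern is to key such decisions on
    the global rank instead; a minimal sketch (``RANK`` is the environment
    variable set by the launcher; ``save_checkpoint`` is a hypothetical helper):

    ::

        >>> # xdoctest: +SKIP
        >>> import os
        >>> if int(os.environ["RANK"]) == 0:  # global rank, unique across all nodes
        >>>     save_checkpoint(model)  # hypothetical helper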
"""
from typing_extensions import deprecated as _deprecated

from torch.distributed.run import get_args_parser, run


def parse_args(args):
    parser = get_args_parser()
    parser.add_argument(
        "--use-env",
        "--use_env",
        default=False,
        action="store_true",
        help="Use environment variable to pass "
        "'local rank'. For legacy reasons, the default value is False. "
        "If set to True, the script will not pass "
        "--local-rank as argument, and will instead set LOCAL_RANK.",
    )
    return parser.parse_args(args)


def launch(args):
    if args.no_python and not args.use_env:
        raise ValueError(
            "When using the '--no-python' flag,"
            " you must also set the '--use-env' flag."
        )
    run(args)


@_deprecated(
    "The module torch.distributed.launch is deprecated\n"
    "and will be removed in a future release. Use torchrun.\n"
    "Note that --use-env is set by default in torchrun.\n"
    "If your script expects `--local-rank` argument to be set, please\n"
    "change it to read from `os.environ['LOCAL_RANK']` instead. See \n"
    "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n"
    "further instructions\n",
    category=FutureWarning,
)
def main(args=None):
    args = parse_args(args)
    launch(args)


if __name__ == "__main__":
    main()