data_parallel.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # mypy: allow-untyped-defs
  2. import operator
  3. import torch
  4. import warnings
  5. from itertools import chain
  6. from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union
  7. from ..modules import Module
  8. from .scatter_gather import scatter_kwargs, gather
  9. from .replicate import replicate
  10. from .parallel_apply import parallel_apply
  11. from torch._utils import (
  12. _get_all_device_indices,
  13. _get_available_device_type,
  14. _get_device_index,
  15. _get_devices_properties
  16. )
  17. __all__ = ['DataParallel', 'data_parallel']
  18. def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None:
  19. imbalance_warn = """
  20. There is an imbalance between your GPUs. You may want to exclude GPU {} which
  21. has less than 75% of the memory or cores of GPU {}. You can do so by setting
  22. the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
  23. environment variable."""
  24. device_ids = [_get_device_index(x, True) for x in device_ids]
  25. dev_props = _get_devices_properties(device_ids)
  26. def warn_imbalance(get_prop):
  27. values = [get_prop(props) for props in dev_props]
  28. min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
  29. max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
  30. if min_val / max_val < 0.75:
  31. warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
  32. return True
  33. return False
  34. if warn_imbalance(lambda props: props.total_memory):
  35. return
  36. if warn_imbalance(lambda props: props.multi_processor_count):
  37. return
  38. T = TypeVar("T", bound=Module)
  39. class DataParallel(Module, Generic[T]):
  40. r"""Implements data parallelism at the module level.
  41. This container parallelizes the application of the given :attr:`module` by
  42. splitting the input across the specified devices by chunking in the batch
  43. dimension (other objects will be copied once per device). In the forward
  44. pass, the module is replicated on each device, and each replica handles a
  45. portion of the input. During the backwards pass, gradients from each replica
  46. are summed into the original module.
  47. The batch size should be larger than the number of GPUs used.
  48. .. warning::
  49. It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`,
  50. instead of this class, to do multi-GPU training, even if there is only a single
  51. node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`.
  52. Arbitrary positional and keyword inputs are allowed to be passed into
  53. DataParallel but some types are specially handled. tensors will be
  54. **scattered** on dim specified (default 0). tuple, list and dict types will
  55. be shallow copied. The other types will be shared among different threads
  56. and can be corrupted if written to in the model's forward pass.
  57. The parallelized :attr:`module` must have its parameters and buffers on
  58. ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
  59. module.
  60. .. warning::
  61. In each forward, :attr:`module` is **replicated** on each device, so any
  62. updates to the running module in ``forward`` will be lost. For example,
  63. if :attr:`module` has a counter attribute that is incremented in each
  64. ``forward``, it will always stay at the initial value because the update
  65. is done on the replicas which are destroyed after ``forward``. However,
  66. :class:`~torch.nn.DataParallel` guarantees that the replica on
  67. ``device[0]`` will have its parameters and buffers sharing storage with
  68. the base parallelized :attr:`module`. So **in-place** updates to the
  69. parameters or buffers on ``device[0]`` will be recorded. E.g.,
  70. :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm`
  71. rely on this behavior to update the buffers.
  72. .. warning::
  73. Forward and backward hooks defined on :attr:`module` and its submodules
  74. will be invoked ``len(device_ids)`` times, each with inputs located on
  75. a particular device. Particularly, the hooks are only guaranteed to be
  76. executed in correct order with respect to operations on corresponding
  77. devices. For example, it is not guaranteed that hooks set via
  78. :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
  79. `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
  80. that each such hook be executed before the corresponding
  81. :meth:`~torch.nn.Module.forward` call of that device.
  82. .. warning::
  83. When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
  84. :func:`forward`, this wrapper will return a vector of length equal to
  85. number of devices used in data parallelism, containing the result from
  86. each device.
  87. .. note::
  88. There is a subtlety in using the
  89. ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
  90. :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
  91. See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
  92. details.
  93. Args:
  94. module (Module): module to be parallelized
  95. device_ids (list of int or torch.device): CUDA devices (default: all devices)
  96. output_device (int or torch.device): device location of output (default: device_ids[0])
  97. Attributes:
  98. module (Module): the module to be parallelized
  99. Example::
  100. >>> # xdoctest: +SKIP
  101. >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
  102. >>> output = net(input_var) # input_var can be on any device, including CPU
  103. """
  104. # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
  105. def __init__(
  106. self,
  107. module: T,
  108. device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
  109. output_device: Optional[Union[int, torch.device]] = None,
  110. dim: int = 0,
  111. ) -> None:
  112. super().__init__()
  113. torch._C._log_api_usage_once("torch.nn.parallel.DataParallel")
  114. device_type = _get_available_device_type()
  115. if device_type is None:
  116. self.module = module
  117. self.device_ids = []
  118. return
  119. if device_ids is None:
  120. device_ids = _get_all_device_indices()
  121. if device_ids is None:
  122. raise RuntimeError("no available devices were found")
  123. if output_device is None:
  124. output_device = device_ids[0]
  125. self.dim = dim
  126. self.module = module
  127. self.device_ids = [_get_device_index(x, True) for x in device_ids]
  128. self.output_device = _get_device_index(output_device, True)
  129. self.src_device_obj = torch.device(device_type, self.device_ids[0])
  130. if device_type == "cuda":
  131. _check_balance(self.device_ids)
  132. if len(self.device_ids) == 1:
  133. self.module.to(self.src_device_obj)
  134. def forward(self, *inputs: Any, **kwargs: Any) -> Any:
  135. with torch.autograd.profiler.record_function("DataParallel.forward"):
  136. if not self.device_ids:
  137. return self.module(*inputs, **kwargs)
  138. for t in chain(self.module.parameters(), self.module.buffers()):
  139. if t.device != self.src_device_obj:
  140. raise RuntimeError("module must have its parameters and buffers "
  141. f"on device {self.src_device_obj} (device_ids[0]) but found one of "
  142. f"them on device: {t.device}")
  143. inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
  144. # for forward function without any inputs, empty list and dict will be created
  145. # so the module can be executed on one device which is the first one in device_ids
  146. if not inputs and not module_kwargs:
  147. inputs = ((),)
  148. module_kwargs = ({},)
  149. if len(self.device_ids) == 1:
  150. return self.module(*inputs[0], **module_kwargs[0])
  151. replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
  152. outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  153. return self.gather(outputs, self.output_device)
  154. def replicate(self, module: T, device_ids: Sequence[Union[int, torch.device]]) -> List[T]:
  155. return replicate(module, device_ids, not torch.is_grad_enabled())
  156. def scatter(
  157. self,
  158. inputs: Tuple[Any, ...],
  159. kwargs: Optional[Dict[str, Any]],
  160. device_ids: Sequence[Union[int, torch.device]],
  161. ) -> Any:
  162. return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  163. def parallel_apply(self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any) -> List[Any]:
  164. return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  165. def gather(self, outputs: Any, output_device: Union[int, torch.device]) -> Any:
  166. return gather(outputs, output_device, dim=self.dim)
  167. def data_parallel(
  168. module: Module,
  169. inputs: Any,
  170. device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
  171. output_device: Optional[Union[int, torch.device]] = None,
  172. dim: int = 0,
  173. module_kwargs: Optional[Any] = None,
  174. ) -> torch.Tensor:
  175. r"""Evaluate module(input) in parallel across the GPUs given in device_ids.
  176. This is the functional version of the DataParallel module.
  177. Args:
  178. module (Module): the module to evaluate in parallel
  179. inputs (Tensor): inputs to the module
  180. device_ids (list of int or torch.device): GPU ids on which to replicate module
  181. output_device (list of int or torch.device): GPU location of the output Use -1 to indicate the CPU.
  182. (default: device_ids[0])
  183. Returns:
  184. a Tensor containing the result of module(input) located on
  185. output_device
  186. """
  187. if not isinstance(inputs, tuple):
  188. inputs = (inputs,) if inputs is not None else ()
  189. device_type = _get_available_device_type()
  190. if device_type is None:
  191. raise RuntimeError("device type could not be determined")
  192. if device_ids is None:
  193. device_ids = _get_all_device_indices()
  194. if device_ids is None:
  195. raise RuntimeError("no available devices were found")
  196. if output_device is None:
  197. output_device = device_ids[0]
  198. device_ids = [_get_device_index(x, True) for x in device_ids]
  199. output_device = _get_device_index(output_device, True)
  200. src_device_obj = torch.device(device_type, device_ids[0])
  201. for t in chain(module.parameters(), module.buffers()):
  202. if t.device != src_device_obj:
  203. raise RuntimeError("module must have its parameters and buffers "
  204. f"on device {src_device_obj} (device_ids[0]) but found one of "
  205. f"them on device: {t.device}")
  206. inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
  207. # for module without any inputs, empty list and dict will be created
  208. # so the module can be executed on one device which is the first one in device_ids
  209. if not inputs and not module_kwargs:
  210. inputs = ((),)
  211. module_kwargs = ({},)
  212. assert module_kwargs is not None
  213. if len(device_ids) == 1:
  214. return module(*inputs[0], **module_kwargs[0])
  215. used_device_ids = device_ids[:len(inputs)]
  216. replicas = replicate(module, used_device_ids)
  217. outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
  218. return gather(outputs, output_device, dim)