optimizer.py

# mypy: allow-untyped-defs
import logging
from collections import defaultdict
from threading import Lock
from typing import List, Optional

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.distributed.rpc import RRef

from .utils import functional_optim_map


__all__ = ["DistributedOptimizer"]

logger = logging.getLogger(__name__)


# XXX: we define a _ScriptLocalOptimizer here to explicitly
# compile the functional optimizer class into TorchScript.
# This is because a ScriptClass instance still lives in
# Python unless you explicitly compile it as an attribute
# of a ScriptModule or pass it to a ScriptFunction.
# _ScriptLocalOptimizerInterface serves as a common
# interface type for optimizer ScriptModules.
#
# TODO (wanchaol): remove this once we add TorchScript
# class reference semantics
@jit.interface
class _ScriptLocalOptimizerInterface:
    def step(self, autograd_ctx_id: int) -> None:
        pass


class _ScriptLocalOptimizer(nn.Module):
    # TorchScript does not support multithread concurrent compiling.
    # request_callback might invoke concurrent compiling, so we
    # serialize the compiling with a lock
    compile_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        super().__init__()
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    @jit.export
    def step(self, autograd_ctx_id: int):
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
        # apply functional optimizer step with a list of gradients
        grads: List[Optional[Tensor]] = [
            all_local_grads[p] if p in all_local_grads else None
            for p in self._local_params
        ]

        self.optim.step(grads)
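
# Illustrative sketch: the functional optimizers referenced by
# functional_optim_map take gradients explicitly rather than reading them
# from param.grad, which is what makes the TorchScript path above possible.
# Assuming _FunctionalSGD is exposed by torch.distributed.optim, a plain
# local (non-RPC) call would look roughly like:
#
#   from torch.distributed.optim import _FunctionalSGD
#
#   params = [torch.randn(2, requires_grad=True)]
#   functional_sgd = _FunctionalSGD(params, lr=0.05)
#   grads = [torch.ones_like(p) for p in params]  # one entry per param, or None
#   functional_sgd.step(grads)                    # no dependence on param.grad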


# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once
# we have converted all to functional optimizer in distributed.optim
class _LocalOptimizer:
    # Ideally we would only need to share a lock for instances of
    # _LocalOptimizer that deal with the same parameters. We are
    # making a simplifying assumption here that if there is more
    # than one instance of _LocalOptimizer per worker, they will
    # be optimizing the same parameters (e.g. each data parallel
    # trainer will create its own instance of _LocalOptimizer but
    # they will all optimize the same parameters on each worker)
    global_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    def step(self, autograd_ctx_id):
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)

        with _LocalOptimizer.global_lock:
            for param, grad in all_local_grads.items():
                param.grad = grad
            self.optim.step()


def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))


def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)


# new/step functions combined with _ScriptLocalOptimizer to provide a
# GIL-free optimizer
def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)
    with _ScriptLocalOptimizer.compile_lock:
        script_optim = jit.script(optim)
        return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface)


@jit.script
def _script_local_optimizer_step(
    local_optim_rref: RRef[_ScriptLocalOptimizerInterface], autograd_ctx_id: int
) -> None:
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)


def _wait_for_all(rpc_futs):
    # TODO: improve error propagation
    exception = None
    results = []
    for fut in rpc_futs:
        try:
            results.append(fut.wait())
        except Exception as e:
            results.append(e)
            exception = e
    if exception is not None:
        raise exception
    return results
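
# Illustrative sketch: assuming an RPC agent is already initialized and a peer
# named "worker1" exists, _wait_for_all gathers the results of several
# in-flight RPCs and re-raises the last exception seen, if any:
#
#   futs = [
#       rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 1))
#       for _ in range(3)
#   ]
#   results = _wait_for_all(futs)  # blocks until every future has completed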


class DistributedOptimizer:
    """
    DistributedOptimizer takes remote references to parameters scattered
    across workers and applies the given optimizer locally for each parameter.

    This class uses :meth:`~torch.distributed.autograd.get_gradients` in order
    to retrieve the gradients for specific parameters.

    Concurrent calls to
    :meth:`~torch.distributed.optim.DistributedOptimizer.step`,
    either from the same or different clients, will
    be serialized on each worker -- as each worker's optimizer can only work
    on one set of gradients at a time. However, there is no guarantee that
    the full forward-backward-optimizer sequence will execute for one client
    at a time. This means that the gradients being applied may not correspond
    to the latest forward pass executed on a given worker. Also, there is no
    guaranteed ordering across workers.

    `DistributedOptimizer` creates the local optimizer with TorchScript enabled
    by default, so that optimizer updates are not blocked by the Python Global
    Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed
    Model Parallel). This feature is currently enabled for most optimizers. You
    can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support
    for your own custom optimizers.

    Args:
        optimizer_class (optim.Optimizer): the class of optimizer to
            instantiate on each worker.
        params_rref (list[RRef]): list of RRefs to local or remote parameters
            to optimize.
        args: arguments to pass to the optimizer constructor on each worker.
        kwargs: arguments to pass to the optimizer constructor on each worker.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> import torch.distributed.autograd as dist_autograd
        >>> import torch.distributed.rpc as rpc
        >>> from torch import optim
        >>> from torch.distributed.optim import DistributedOptimizer
        >>>
        >>> with dist_autograd.context() as context_id:
        >>>     # Forward pass.
        >>>     rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>>     rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>>     loss = rref1.to_here() + rref2.to_here()
        >>>
        >>>     # Backward pass.
        >>>     dist_autograd.backward(context_id, [loss.sum()])
        >>>
        >>>     # Optimizer.
        >>>     dist_optim = DistributedOptimizer(
        >>>         optim.SGD,
        >>>         [rref1, rref2],
        >>>         lr=0.05,
        >>>     )
        >>>     dist_optim.step(context_id)

    __ https://github.com/pytorch/tutorials/pull/1465
    """

    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        torch._C._log_api_usage_once("torch.distributed.optim.DistributedOptimizer")
        per_worker_params_rref = defaultdict(list)
        for param in params_rref:
            per_worker_params_rref[param.owner()].append(param)

        if optimizer_class in functional_optim_map and jit._state._enabled:
            optim_ctor = functional_optim_map.get(optimizer_class)
        else:
            optim_ctor = optimizer_class
        self.is_functional_optim = optim_ctor != optimizer_class

        if self.is_functional_optim:
            optimizer_new_func = _new_script_local_optimizer
        else:
            logger.warning(
                "Creating the optimizer %s without TorchScript support; "
                "this might result in slow computation time in a multithreaded "
                "environment (i.e. Distributed Model Parallel training on CPU) "
                "due to Python's Global Interpreter Lock (GIL). Please file an "
                "issue if you need this optimizer in TorchScript.",
                optimizer_class,
            )
            optimizer_new_func = _new_local_optimizer

        remote_optim_futs = []
        for worker, param_rrefs in per_worker_params_rref.items():
            remote_optim_rref_fut = rpc.rpc_async(
                worker,
                optimizer_new_func,
                args=(optim_ctor, param_rrefs) + args,
                kwargs=kwargs,
            )
            remote_optim_futs.append(remote_optim_rref_fut)

        self.remote_optimizers = _wait_for_all(remote_optim_futs)

    def step(self, context_id):
        """
        Performs a single optimization step.

        This will call :meth:`torch.optim.Optimizer.step` on each worker
        containing parameters to be optimized, and will block until all workers
        return. The provided ``context_id`` will be used to retrieve the
        corresponding :class:`~torch.distributed.autograd.context` that
        contains the gradients that should be applied to the parameters.

        Args:
            context_id: the autograd context id for which we should run the
                optimizer step.
        """
        dist_autograd._is_valid_context(context_id)

        optimizer_step_func = (
            _script_local_optimizer_step
            if self.is_functional_optim
            else _local_optimizer_step
        )

        rpc_futs = []
        for optimizer in self.remote_optimizers:
            rpc_futs.append(
                rpc.rpc_async(
                    optimizer.owner(),
                    optimizer_step_func,
                    args=(optimizer, context_id),
                )
            )
        _wait_for_all(rpc_futs)
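

# Illustrative usage sketch: in a training loop, each iteration typically runs
# inside its own distributed autograd context so that step() picks up exactly
# the gradients produced by that iteration's backward pass. Assuming dist_optim
# and the remote parameter RRefs (rref1, rref2) were created as in the
# class-level example above, and num_iterations is defined by the caller:
#
#   for _ in range(num_iterations):
#       with dist_autograd.context() as context_id:
#           loss = rref1.to_here() + rref2.to_here()
#           dist_autograd.backward(context_id, [loss.sum()])
#           dist_optim.step(context_id)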