functional_sgd.py

# mypy: allow-untyped-defs
from typing import Dict, List, Optional

import torch
import torch.optim._functional as F
from torch import Tensor

__all__: List[str] = []


# Define a TorchScript compatible Functional SGD Optimizer
# where we use this optimizer in a functional way.
# Instead of using `param.grad` when updating parameters, we
# explicitly allow the distributed optimizer to pass gradients to
# the `step` function. This way we can separate the gradients and
# parameters, and allow multithreaded trainers to update the
# parameters without data races from accumulating into the same `.grad`.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
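# Illustrative usage, not part of the original file (`w`, `g`, and the values
# below are made-up placeholders): gradients are computed elsewhere and handed
# to the optimizer explicitly instead of being read from `param.grad`.
#
#     w = torch.randn(4)
#     g = torch.randn(4)  # gradient produced by the trainer / distributed machinery
#     opt = _FunctionalSGD([w], lr=0.1, momentum=0.9)
#     opt.step([g])  # one gradient per parameter, in parameter order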
@torch.jit.script
class _FunctionalSGD:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-2,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        nesterov: bool = False,
        maximize: bool = False,
        foreach: bool = False,
        fused: bool = False,
        _allow_empty_param_list: bool = False,
    ):
        self.defaults = {
            "lr": lr,
            "momentum": momentum,
            "dampening": dampening,
            "weight_decay": weight_decay,
        }
        self.nesterov = nesterov
        self.maximize = maximize
        self.foreach = foreach
        self.fused = fused
        self.state = torch.jit.annotate(
            Dict[torch.Tensor, Dict[str, torch.Tensor]], {}
        )

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

        # NOTE: we only have one param_group and don't allow the user to add
        # additional param groups, as that is not a common use case.
        self.param_group = {"params": params}
    def step_param(self, param: Tensor, grad: Optional[Tensor]):
        """Similar to self.step, but operates on a single parameter and
        its gradient.
        """
        # TODO: Once the step_param interface is robust, refactor step to call
        # step_param on each param.
        weight_decay = self.defaults["weight_decay"]
        momentum = self.defaults["momentum"]
        dampening = self.defaults["dampening"]
        lr = self.defaults["lr"]
        params = [param]
        momentum_buffer_list: List[Optional[Tensor]] = []
        grads = []
        has_sparse_grad = False
        if grad is not None:
            grads.append(grad)
            if grad.is_sparse:
                has_sparse_grad = True
            if param not in self.state:
                self.state[param] = {}
            state = self.state[param]
            if "momentum_buffer" not in state:
                momentum_buffer_list.append(None)
            else:
                momentum_buffer_list.append(state["momentum_buffer"])

        with torch.no_grad():
            F.sgd(
                params,
                grads,
                momentum_buffer_list,
                weight_decay=weight_decay,
                momentum=momentum,
                lr=lr,
                dampening=dampening,
                nesterov=self.nesterov,
                maximize=self.maximize,
                has_sparse_grad=has_sparse_grad,
                foreach=self.foreach,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
            )

        # update momentum_buffer in state
        state = self.state[param]
        momentum_buffer = momentum_buffer_list[0]
        if momentum_buffer is not None:
            state["momentum_buffer"] = momentum_buffer
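    # Illustrative note, not part of the original file (`w` and `g` are the
    # placeholders from the module-level example): a caller may update one
    # parameter at a time via `step_param`; when `momentum` is nonzero, the
    # momentum buffer produced by `F.sgd` is stored under
    # `self.state[param]["momentum_buffer"]` and reused on the next update.
    #
    #     opt.step_param(w, g)
    #     buf = opt.state[w]["momentum_buffer"]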
    def step(self, gradients: List[Optional[Tensor]]):
        """Perform a single optimization step using the explicitly supplied
        gradients, one per parameter; ``None`` entries are skipped.
        """
        params = self.param_group["params"]
        params_with_grad = []
        grads = []
        momentum_buffer_list: List[Optional[Tensor]] = []
        lr = self.defaults["lr"]
        weight_decay = self.defaults["weight_decay"]
        momentum = self.defaults["momentum"]
        dampening = self.defaults["dampening"]

        if len(params) != len(gradients):
            raise ValueError(
                "the number of gradients passed in does not equal the number of parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        has_sparse_grad = False
        for param, gradient in zip(params, gradients):
            if gradient is not None:
                params_with_grad.append(param)
                grads.append(gradient)
                if gradient.is_sparse:
                    has_sparse_grad = True

                if param not in self.state:
                    self.state[param] = {}

                state = self.state[param]
                if "momentum_buffer" not in state:
                    momentum_buffer_list.append(None)
                else:
                    momentum_buffer_list.append(state["momentum_buffer"])

        with torch.no_grad():
            F.sgd(
                params_with_grad,
                grads,
                momentum_buffer_list,
                weight_decay=weight_decay,
                momentum=momentum,
                lr=lr,
                dampening=dampening,
                nesterov=self.nesterov,
                maximize=self.maximize,
                has_sparse_grad=has_sparse_grad,
                foreach=self.foreach,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
            )

        # update momentum_buffers in state
        for i, p in enumerate(params_with_grad):
            state = self.state[p]
            momentum_buffer = momentum_buffer_list[i]
            if momentum_buffer is not None:
                state["momentum_buffer"] = momentum_buffer