functional_adam.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # mypy: allow-untyped-defs
  2. from typing import Dict, List, Optional, Tuple
  3. import torch
  4. import torch.optim._functional as F
  5. from torch import Tensor
  6. __all__: List[str] = []
  7. # Define a TorchScript compatible Functional Adam Optimizer
  8. # where we use these optimizer in a functional way.
  9. # Instead of using the `param.grad` when updating parameters,
  10. # we explicitly allow the distributed optimizer pass gradients to
  11. # the `step` function. In this way, we could separate the gradients
  12. # and parameters and allow multithreaded trainer to update the
  13. # parameters without data traces on accumulating to the same .grad.
  14. # NOTE: This should be only used by distributed optimizer internals
  15. # and not meant to expose to the user.
  16. @torch.jit.script
  17. class _FunctionalAdam:
  18. def __init__(
  19. self,
  20. params: List[Tensor],
  21. lr: float = 1e-3,
  22. betas: Tuple[float, float] = (0.9, 0.999),
  23. eps: float = 1e-8,
  24. weight_decay: float = 0.0,
  25. amsgrad: bool = False,
  26. maximize: bool = False,
  27. foreach: bool = False,
  28. fused: bool = False,
  29. _allow_empty_param_list: bool = False,
  30. ):
  31. if not 0.0 <= lr:
  32. raise ValueError(f"Invalid learning rate: {lr}")
  33. if not 0.0 <= eps:
  34. raise ValueError(f"Invalid epsilon value: {eps}")
  35. if not 0.0 <= betas[0] < 1.0:
  36. raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
  37. if not 0.0 <= betas[1] < 1.0:
  38. raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
  39. if not 0.0 <= weight_decay:
  40. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  41. self.defaults = {
  42. "lr": lr,
  43. "eps": eps,
  44. "beta1": betas[0],
  45. "beta2": betas[1],
  46. "weight_decay": weight_decay,
  47. }
  48. self.amsgrad = amsgrad
  49. self.maximize = maximize
  50. self.foreach = foreach
  51. self.fused = fused
  52. self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
  53. if len(params) == 0 and not _allow_empty_param_list:
  54. raise ValueError("optimizer got an empty parameter list")
  55. # NOTE: we only have one param_group and don't allow user to add additional
  56. # param group as it's not a common use case.
  57. self.param_group = {"params": params}
  58. def step_param(self, param: Tensor, grad: Optional[Tensor]):
  59. """
  60. Similar to step, but operates on a single parameter and optionally a
  61. gradient tensor.
  62. """
  63. params_with_grad = []
  64. grads = []
  65. exp_avgs = []
  66. exp_avg_sqs = []
  67. max_exp_avg_sqs = []
  68. state_steps: List[Tensor] = []
  69. has_complex = torch.is_complex(param)
  70. if grad is not None:
  71. params_with_grad.append(param)
  72. grads.append(grad)
  73. if param not in self.state:
  74. self.state[param] = {}
  75. state = self.state[param]
  76. state["step"] = torch.tensor(0.0)
  77. state["exp_avg"] = torch.zeros_like(
  78. param, memory_format=torch.preserve_format
  79. )
  80. state["exp_avg_sq"] = torch.zeros_like(
  81. param, memory_format=torch.preserve_format
  82. )
  83. if self.amsgrad:
  84. state["max_exp_avg_sq"] = torch.zeros_like(
  85. param, memory_format=torch.preserve_format
  86. )
  87. state = self.state[param]
  88. exp_avgs.append(state["exp_avg"])
  89. exp_avg_sqs.append(state["exp_avg_sq"])
  90. if self.amsgrad:
  91. max_exp_avg_sqs.append(state["max_exp_avg_sq"])
  92. state_steps.append(state["step"])
  93. with torch.no_grad():
  94. F.adam(
  95. params_with_grad,
  96. grads,
  97. exp_avgs,
  98. exp_avg_sqs,
  99. max_exp_avg_sqs,
  100. state_steps,
  101. amsgrad=self.amsgrad,
  102. has_complex=has_complex,
  103. maximize=self.maximize,
  104. beta1=self.defaults["beta1"],
  105. beta2=self.defaults["beta2"],
  106. lr=self.defaults["lr"],
  107. weight_decay=self.defaults["weight_decay"],
  108. eps=self.defaults["eps"],
  109. foreach=self.foreach,
  110. fused=self.fused,
  111. grad_scale=None,
  112. found_inf=None,
  113. )
  114. def step(self, gradients: List[Optional[Tensor]]):
  115. params = self.param_group["params"]
  116. params_with_grad = []
  117. grads = []
  118. exp_avgs = []
  119. exp_avg_sqs = []
  120. max_exp_avg_sqs = []
  121. state_steps: List[Tensor] = []
  122. has_complex = False
  123. if len(params) != len(gradients):
  124. raise ValueError(
  125. "the gradients passed in does not equal to the size of the parameters!"
  126. + f"Params length: {len(params)}. "
  127. + f"Gradients length: {len(gradients)}"
  128. )
  129. for param, gradient in zip(self.param_group["params"], gradients):
  130. if gradient is not None:
  131. has_complex |= torch.is_complex(param)
  132. params_with_grad.append(param)
  133. grads.append(gradient)
  134. # Lazy state initialization
  135. if param not in self.state:
  136. self.state[param] = {}
  137. state = self.state[param]
  138. state["step"] = torch.tensor(0.0)
  139. # Exponential moving average of gradient values
  140. state["exp_avg"] = torch.zeros_like(
  141. param, memory_format=torch.preserve_format
  142. )
  143. # Exponential moving average of squared gradient values
  144. state["exp_avg_sq"] = torch.zeros_like(
  145. param, memory_format=torch.preserve_format
  146. )
  147. if self.amsgrad:
  148. # Maintains max of all exp. moving avg. of sq. grad. values
  149. state["max_exp_avg_sq"] = torch.zeros_like(
  150. param, memory_format=torch.preserve_format
  151. )
  152. state = self.state[param]
  153. exp_avgs.append(state["exp_avg"])
  154. exp_avg_sqs.append(state["exp_avg_sq"])
  155. if self.amsgrad:
  156. max_exp_avg_sqs.append(state["max_exp_avg_sq"])
  157. state_steps.append(state["step"])
  158. with torch.no_grad():
  159. F.adam(
  160. params_with_grad,
  161. grads,
  162. exp_avgs,
  163. exp_avg_sqs,
  164. max_exp_avg_sqs,
  165. state_steps,
  166. amsgrad=self.amsgrad,
  167. has_complex=has_complex,
  168. maximize=self.maximize,
  169. beta1=self.defaults["beta1"],
  170. beta2=self.defaults["beta2"],
  171. lr=self.defaults["lr"],
  172. weight_decay=self.defaults["weight_decay"],
  173. eps=self.defaults["eps"],
  174. foreach=self.foreach,
  175. fused=self.fused,
  176. grad_scale=None,
  177. found_inf=None,
  178. )