adam.py

# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _dispatch_sqrt,
    _foreach_doc,
    _fused_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _stack_if_compiling,
    _use_grad_for_differentiable,
    _view_as_real,
    DeviceDict,
    Optimizer,
    ParamsT,
)

__all__ = ["Adam", "adam"]


class Adam(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if isinstance(lr, Tensor) and foreach and not capturable:
            raise ValueError(
                "lr as a Tensor is not supported for capturable=False and foreach=True"
            )
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            # TODO(crcrpar): [low prec params & their higher prec copy]
            # Support AMP with FP16/BF16 model params which would need
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
            fused_supported_devices = _get_fused_kernels_supported_devices()
            if not all(
                p.device.type in fused_supported_devices and torch.is_floating_point(p)
                for pg in self.param_groups
                for p in pg["params"]
            ):
                raise RuntimeError(
                    "`fused=True` requires all the params to be floating point Tensors of "
                    f"supported devices: {fused_supported_devices}."
                )
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val,
                            dtype=_get_scalar_dtype(is_fused=fused),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    # note(crcrpar): [special device hosting for step]
                    # Deliberately host `step` on CPU if both capturable and fused are off.
                    # This is because kernel launches are costly on CUDA and XLA.
                    state["step"] = (
                        torch.zeros(
                            (),
                            dtype=_get_scalar_dtype(is_fused=group["fused"]),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
                    )
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if group["amsgrad"]:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                if group["amsgrad"]:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])
                if group["differentiable"] and state["step"].requires_grad:
                    raise RuntimeError(
                        "`requires_grad` is not supported for `step` in differentiable mode"
                    )

                # Foreach without capturable does not support a tensor lr
                if (
                    group["foreach"]
                    and torch.is_tensor(group["lr"])
                    and not group["capturable"]
                ):
                    raise RuntimeError(
                        "lr as a Tensor is not supported for capturable=False and foreach=True"
                    )

                state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            max_exp_avg_sqs: List[Tensor] = []
            state_steps: List[Tensor] = []
            beta1, beta2 = group["betas"]

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            adam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=group["amsgrad"],
                has_complex=has_complex,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss


Adam.__doc__ = (
    r"""Implements Adam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
            &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
                \:\textit{maximize} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
                v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
            &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
            &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
            &\hspace{5mm}\textbf{if} \: amsgrad \\
            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
                \widehat{v_t}) \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
            is not yet supported for all our implementations. Please use a float
            LR if you are not also specifying fused=True or capturable=True.
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}
        {_fused_doc}

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """
)
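

# Illustrative usage sketch (not part of this module's API; `data_loader` and the
# model/loss below are assumed placeholders for a typical training loop):
#
#     model = torch.nn.Linear(10, 1)
#     optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
#     for inputs, targets in data_loader:
#         optimizer.zero_grad()
#         loss = torch.nn.functional.mse_loss(model(inputs), targets)
#         loss.backward()
#         optimizer.step()
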
def _single_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # this assert is due to JIT being dumb and not realizing that the ops below
        # have overloads to handle both float and Tensor lrs, so we just assert it's
        # a float since most people using JIT are using floats
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            if amsgrad:
                max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
            param = torch.view_as_real(param)

        # Decay the first and second moment running average coefficient
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

        if capturable or differentiable:
            step = step_t

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                if differentiable:
                    max_exp_avg_sq = max_exp_avg_sqs[i].clone()
                else:
                    max_exp_avg_sq = max_exp_avg_sqs[i]

                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))

                # Uses the max. for normalizing running avg. of gradient
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                denom = (
                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
            else:
                denom = (
                    exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
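            # With denom folded this way, the addcdiv_ below computes
            #   param += exp_avg / denom
            #         == param - (lr / (1 - beta1^t)) * exp_avg / (sqrt(v_hat) + eps),
            # i.e. the standard bias-corrected Adam update, without an extra
            # param-sized temporary.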
            param.addcdiv_(exp_avg, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)

        # Lastly, switch back to complex view
        if amsgrad and torch.is_complex(params[i]):
            max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])


def _multi_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert grad_scale is None and found_inf is None

    assert not differentiable, "_foreach ops don't support autograd"

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]
    )
    for (
        device_params,
        device_grads,
        device_exp_avgs,
        device_exp_avg_sqs,
        device_max_exp_avg_sqs,
        device_state_steps,
    ), _ in grouped_tensors.values():
        # Handle complex parameters
        if has_complex:
            if amsgrad:
                _view_as_real(
                    device_params,
                    device_grads,
                    device_exp_avgs,
                    device_exp_avg_sqs,
                    device_max_exp_avg_sqs,
                )
            else:
                _view_as_real(
                    device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
                )

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            device_exp_avg_sqs, device_grads, device_grads, 1 - beta2
        )

        # Delete the local intermediate since it won't be used anymore to save on peak memory
        del device_grads

        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]

        if capturable:
            bias_correction1 = torch._foreach_pow(beta1, device_state_steps)
            bias_correction2 = torch._foreach_pow(beta2, device_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            # we do not negate bias_correction1 as it'll need to be negated later anyway
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            torch._foreach_div_(bias_correction1, lr)
            torch._foreach_reciprocal_(bias_correction1)

            torch._foreach_sqrt_(bias_correction2)

            # Re-assign for clarity as we maintain minimal intermediates: we'll have
            # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
            # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
            step_size = bias_correction1
            bias_correction2_sqrt = bias_correction2

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)  # type: ignore[assignment]

                # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_div_(exp_avg_sq_sqrt, step_size)
            # at this point, exp_avg_sq_sqrt = - (1 - beta1^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
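            # Net effect: device_params -= lr * (exp_avg / (1 - beta1^t))
            #             / (sqrt(exp_avg_sq / (1 - beta2^t)) + eps),
            # matching the single-tensor implementation above.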
        else:
            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in device_state_steps
            ]
            bias_correction2 = [
                1 - beta2 ** _get_value(step) for step in device_state_steps
            ]

            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]  # type: ignore[arg-type]

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_addcdiv_(
                device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size  # type: ignore[arg-type]
            )


def _fused_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,  # Needed for consistency.
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
) -> None:
    if not params:
        return
    if differentiable:
        raise RuntimeError("Adam with fused=True does not support differentiable=True")

    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

    # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
    # treating it as a scalar.
    lr_dict: Optional[DeviceDict] = (
        {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
    )

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]
    )
    for (device, _), (
        (
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
        ),
        _,
    ) in grouped_tensors.items():
        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            device_grad_scale = grad_scale_dict.setdefault(
                device, grad_scale.to(device, non_blocking=True)
            )
        if found_inf is not None:
            device_found_inf = found_inf_dict.setdefault(
                device, found_inf.to(device, non_blocking=True)
            )
        if lr_dict is not None and device not in lr_dict:
            lr_dict[device] = lr.to(device=device, non_blocking=True)  # type: ignore[union-attr]
            lr = lr_dict[device]
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adam_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
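        # If the grad scaler reported infs/NaNs for this device, the fused kernel
        # skips the parameter update, so the step increment applied above is
        # rolled back below to keep `step` consistent.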
        if device_found_inf is not None:
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adam
    else:
        func = _single_tensor_adam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
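

# Illustrative sketch of calling the functional entry point directly (not part of
# this module; the tensors below are made-up placeholders). The caller owns the
# state tensors, and torch.no_grad() is needed because the update is applied
# in-place to a leaf tensor that requires grad:
#
#     p = torch.randn(3, requires_grad=True)
#     p.grad = torch.randn(3)
#     exp_avg, exp_avg_sq = torch.zeros(3), torch.zeros(3)
#     step = torch.tensor(0.0)
#     with torch.no_grad():
#         adam([p], [p.grad], [exp_avg], [exp_avg_sq], [], [step],
#              amsgrad=False, beta1=0.9, beta2=0.999, lr=1e-3,
#              weight_decay=0.0, eps=1e-8, maximize=False)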