rmsprop.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. # mypy: allow-untyped-defs
  2. from typing import List, Optional
  3. import torch
  4. from torch import Tensor
  5. from .optimizer import (
  6. _capturable_doc,
  7. _default_to_fused_or_foreach,
  8. _differentiable_doc,
  9. _disable_dynamo_if_unsupported,
  10. _foreach_doc,
  11. _get_capturable_supported_devices,
  12. _get_scalar_dtype,
  13. _maximize_doc,
  14. _use_grad_for_differentiable,
  15. _view_as_real,
  16. Optimizer,
  17. ParamsT,
  18. )
  19. __all__ = ["RMSprop", "rmsprop"]
  20. class RMSprop(Optimizer):
  21. def __init__(
  22. self,
  23. params: ParamsT,
  24. lr: float = 1e-2,
  25. alpha: float = 0.99,
  26. eps: float = 1e-8,
  27. weight_decay: float = 0,
  28. momentum: float = 0,
  29. centered=False,
  30. capturable=False,
  31. foreach: Optional[bool] = None,
  32. maximize: bool = False,
  33. differentiable: bool = False,
  34. ):
  35. if not 0.0 <= lr:
  36. raise ValueError(f"Invalid learning rate: {lr}")
  37. if not 0.0 <= eps:
  38. raise ValueError(f"Invalid epsilon value: {eps}")
  39. if not 0.0 <= momentum:
  40. raise ValueError(f"Invalid momentum value: {momentum}")
  41. if not 0.0 <= weight_decay:
  42. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  43. if not 0.0 <= alpha:
  44. raise ValueError(f"Invalid alpha value: {alpha}")
  45. defaults = dict(
  46. lr=lr,
  47. momentum=momentum,
  48. alpha=alpha,
  49. eps=eps,
  50. centered=centered,
  51. weight_decay=weight_decay,
  52. capturable=capturable,
  53. foreach=foreach,
  54. maximize=maximize,
  55. differentiable=differentiable,
  56. )
  57. super().__init__(params, defaults)
  58. def __setstate__(self, state):
  59. super().__setstate__(state)
  60. for group in self.param_groups:
  61. group.setdefault("momentum", 0)
  62. group.setdefault("centered", False)
  63. group.setdefault("foreach", None)
  64. group.setdefault("maximize", False)
  65. group.setdefault("differentiable", False)
  66. group.setdefault("capturable", False)
  67. for p in group["params"]:
  68. p_state = self.state.get(p, [])
  69. if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
  70. step_val = float(p_state["step"])
  71. p_state["step"] = (
  72. torch.tensor(
  73. step_val, dtype=_get_scalar_dtype(), device=p.device
  74. )
  75. if group["capturable"]
  76. else torch.tensor(step_val, dtype=_get_scalar_dtype())
  77. )
  78. def _init_group(
  79. self,
  80. group,
  81. params_with_grad,
  82. grads,
  83. square_avgs,
  84. momentum_buffer_list,
  85. grad_avgs,
  86. state_steps,
  87. ):
  88. has_complex = False
  89. for p in group["params"]:
  90. if p.grad is None:
  91. continue
  92. has_complex |= torch.is_complex(p)
  93. params_with_grad.append(p)
  94. if p.grad.is_sparse:
  95. raise RuntimeError("RMSprop does not support sparse gradients")
  96. grads.append(p.grad)
  97. state = self.state[p]
  98. # State initialization
  99. if len(state) == 0:
  100. state["step"] = (
  101. torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
  102. if group["capturable"]
  103. else torch.zeros((), dtype=_get_scalar_dtype())
  104. )
  105. state["square_avg"] = torch.zeros_like(
  106. p, memory_format=torch.preserve_format
  107. )
  108. if group["momentum"] > 0:
  109. state["momentum_buffer"] = torch.zeros_like(
  110. p, memory_format=torch.preserve_format
  111. )
  112. if group["centered"]:
  113. state["grad_avg"] = torch.zeros_like(
  114. p, memory_format=torch.preserve_format
  115. )
  116. square_avgs.append(state["square_avg"])
  117. state_steps.append(state["step"])
  118. if group["momentum"] > 0:
  119. momentum_buffer_list.append(state["momentum_buffer"])
  120. if group["centered"]:
  121. grad_avgs.append(state["grad_avg"])
  122. return has_complex
  123. @_use_grad_for_differentiable
  124. def step(self, closure=None):
  125. """Performs a single optimization step.
  126. Args:
  127. closure (Callable, optional): A closure that reevaluates the model
  128. and returns the loss.
  129. """
  130. self._cuda_graph_capture_health_check()
  131. loss = None
  132. if closure is not None:
  133. with torch.enable_grad():
  134. loss = closure()
  135. for group in self.param_groups:
  136. params_with_grad: List[Tensor] = []
  137. grads: List[Tensor] = []
  138. square_avgs: List[Tensor] = []
  139. grad_avgs: List[Tensor] = []
  140. momentum_buffer_list: List[Tensor] = []
  141. state_steps: List[Tensor] = []
  142. has_complex = self._init_group(
  143. group,
  144. params_with_grad,
  145. grads,
  146. square_avgs,
  147. momentum_buffer_list,
  148. grad_avgs,
  149. state_steps,
  150. )
  151. rmsprop(
  152. params_with_grad,
  153. grads,
  154. square_avgs,
  155. grad_avgs,
  156. momentum_buffer_list,
  157. state_steps,
  158. lr=group["lr"],
  159. alpha=group["alpha"],
  160. eps=group["eps"],
  161. weight_decay=group["weight_decay"],
  162. momentum=group["momentum"],
  163. centered=group["centered"],
  164. foreach=group["foreach"],
  165. maximize=group["maximize"],
  166. differentiable=group["differentiable"],
  167. capturable=group["capturable"],
  168. has_complex=has_complex,
  169. )
  170. return loss
  171. RMSprop.__doc__ = (
  172. r"""Implements RMSprop algorithm.
  173. .. math::
  174. \begin{aligned}
  175. &\rule{110mm}{0.4pt} \\
  176. &\textbf{input} : \alpha \text{ (alpha)},\: \gamma \text{ (lr)},
  177. \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
  178. &\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: centered\\
  179. &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \:
  180. \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0 \\[-1.ex]
  181. &\rule{110mm}{0.4pt} \\
  182. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  183. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  184. &\hspace{5mm}if \: \lambda \neq 0 \\
  185. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  186. &\hspace{5mm}v_t \leftarrow \alpha v_{t-1} + (1 - \alpha) g^2_t
  187. \hspace{8mm} \\
  188. &\hspace{5mm} \tilde{v_t} \leftarrow v_t \\
  189. &\hspace{5mm}if \: centered \\
  190. &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t \\
  191. &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} - \big(g^{ave}_{t} \big)^2 \\
  192. &\hspace{5mm}if \: \mu > 0 \\
  193. &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} +
  194. g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \\
  195. &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t \\
  196. &\hspace{5mm} else \\
  197. &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} -
  198. \gamma g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \hspace{3mm} \\
  199. &\rule{110mm}{0.4pt} \\[-1.ex]
  200. &\bf{return} \: \theta_t \\[-1.ex]
  201. &\rule{110mm}{0.4pt} \\[-1.ex]
  202. \end{aligned}
  203. For further details regarding the algorithm we refer to
  204. `lecture notes <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_ by G. Hinton.
  205. and centered version `Generating Sequences
  206. With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
  207. The implementation here takes the square root of the gradient average before
  208. adding epsilon (note that TensorFlow interchanges these two operations). The effective
  209. learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma`
  210. is the scheduled learning rate and :math:`v` is the weighted moving average
  211. of the squared gradient.
  212. """
  213. + rf"""
  214. Args:
  215. params (iterable): iterable of parameters to optimize or dicts defining
  216. parameter groups
  217. lr (float, optional): learning rate (default: 1e-2)
  218. momentum (float, optional): momentum factor (default: 0)
  219. alpha (float, optional): smoothing constant (default: 0.99)
  220. eps (float, optional): term added to the denominator to improve
  221. numerical stability (default: 1e-8)
  222. centered (bool, optional) : if ``True``, compute the centered RMSProp,
  223. the gradient is normalized by an estimation of its variance
  224. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  225. {_foreach_doc}
  226. {_maximize_doc}
  227. {_capturable_doc}
  228. {_differentiable_doc}
  229. """
  230. )
  231. def _single_tensor_rmsprop(
  232. params: List[Tensor],
  233. grads: List[Tensor],
  234. square_avgs: List[Tensor],
  235. grad_avgs: List[Tensor],
  236. momentum_buffer_list: List[Tensor],
  237. state_steps: List[Tensor],
  238. *,
  239. lr: float,
  240. alpha: float,
  241. eps: float,
  242. weight_decay: float,
  243. momentum: float,
  244. centered: bool,
  245. maximize: bool,
  246. differentiable: bool,
  247. capturable: bool,
  248. has_complex: bool,
  249. ):
  250. for i, param in enumerate(params):
  251. step = state_steps[i]
  252. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  253. if not torch._utils.is_compiling() and capturable:
  254. capturable_supported_devices = _get_capturable_supported_devices()
  255. assert (
  256. param.device.type == step.device.type
  257. and param.device.type in capturable_supported_devices
  258. ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  259. grad = grads[i]
  260. grad = grad if not maximize else -grad
  261. square_avg = square_avgs[i]
  262. step += 1
  263. if weight_decay != 0:
  264. grad = grad.add(param, alpha=weight_decay)
  265. is_complex_param = torch.is_complex(param)
  266. if is_complex_param:
  267. param = torch.view_as_real(param)
  268. grad = torch.view_as_real(grad)
  269. square_avg = torch.view_as_real(square_avg)
  270. square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
  271. if centered:
  272. grad_avg = grad_avgs[i]
  273. if is_complex_param:
  274. grad_avg = torch.view_as_real(grad_avg)
  275. grad_avg.lerp_(grad, 1 - alpha)
  276. avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_()
  277. else:
  278. avg = square_avg.sqrt()
  279. if differentiable:
  280. avg = avg.add(eps)
  281. else:
  282. avg = avg.add_(eps)
  283. if momentum > 0:
  284. buf = momentum_buffer_list[i]
  285. if is_complex_param:
  286. buf = torch.view_as_real(buf)
  287. buf.mul_(momentum).addcdiv_(grad, avg)
  288. param.add_(buf, alpha=-lr)
  289. else:
  290. param.addcdiv_(grad, avg, value=-lr)
  291. def _multi_tensor_rmsprop(
  292. params: List[Tensor],
  293. grads: List[Tensor],
  294. square_avgs: List[Tensor],
  295. grad_avgs: List[Tensor],
  296. momentum_buffer_list: List[Tensor],
  297. state_steps: List[Tensor],
  298. *,
  299. lr: float,
  300. alpha: float,
  301. eps: float,
  302. weight_decay: float,
  303. momentum: float,
  304. centered: bool,
  305. maximize: bool,
  306. differentiable: bool,
  307. capturable: bool,
  308. has_complex: bool,
  309. ):
  310. if len(params) == 0:
  311. return
  312. assert not differentiable, "_foreach ops don't support autograd"
  313. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  314. if not torch._utils.is_compiling() and capturable:
  315. capturable_supported_devices = _get_capturable_supported_devices()
  316. assert all(
  317. p.device.type == step.device.type
  318. and p.device.type in capturable_supported_devices
  319. for p, step in zip(params, state_steps)
  320. ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  321. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  322. [params, grads, square_avgs, grad_avgs, momentum_buffer_list, state_steps]
  323. )
  324. for (
  325. (
  326. grouped_params,
  327. grouped_grads,
  328. grouped_square_avgs,
  329. grouped_grad_avgs,
  330. grouped_momentum_buffer_list,
  331. grouped_state_steps,
  332. )
  333. ), _ in grouped_tensors.values():
  334. if has_complex:
  335. state_and_grads = [grouped_grads, grouped_square_avgs]
  336. if momentum > 0:
  337. state_and_grads.append(grouped_momentum_buffer_list)
  338. if centered:
  339. state_and_grads.append(grouped_grad_avgs)
  340. _view_as_real(grouped_params, *state_and_grads)
  341. if maximize:
  342. grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment]
  343. # Update steps
  344. # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
  345. # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
  346. # wrapped it once now. The alpha is required to assure we go to the right overload.
  347. if grouped_state_steps[0].is_cpu:
  348. torch._foreach_add_(
  349. grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
  350. )
  351. else:
  352. torch._foreach_add_(grouped_state_steps, 1)
  353. if weight_decay != 0:
  354. # Re-use the intermediate memory (grouped_grads) already allocated for maximize
  355. if maximize:
  356. torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
  357. else:
  358. grouped_grads = torch._foreach_add( # type: ignore[assignment]
  359. grouped_grads, grouped_params, alpha=weight_decay
  360. )
  361. torch._foreach_mul_(grouped_square_avgs, alpha)
  362. torch._foreach_addcmul_(
  363. grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha
  364. )
  365. if centered:
  366. torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha)
  367. avg = torch._foreach_addcmul(
  368. grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1
  369. )
  370. torch._foreach_sqrt_(avg)
  371. torch._foreach_add_(avg, eps)
  372. else:
  373. avg = torch._foreach_sqrt(grouped_square_avgs)
  374. torch._foreach_add_(avg, eps)
  375. if momentum > 0:
  376. torch._foreach_mul_(grouped_momentum_buffer_list, momentum)
  377. torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg)
  378. # If LR is a tensor, the else branch will internally call item()
  379. # which will cause silent incorrectness if we are capturing
  380. if capturable and isinstance(lr, torch.Tensor):
  381. momentum_lr = torch._foreach_mul(grouped_momentum_buffer_list, -lr)
  382. torch._foreach_add_(grouped_params, momentum_lr)
  383. else:
  384. torch._foreach_add_(
  385. grouped_params, grouped_momentum_buffer_list, alpha=-lr
  386. )
  387. else:
  388. # If LR is a tensor, the else branch will internally call item()
  389. # which will cause silent incorrectness if we are capturing
  390. if capturable and isinstance(lr, torch.Tensor):
  391. torch._foreach_div_(avg, -lr)
  392. torch._foreach_addcdiv_(grouped_params, grouped_grads, avg)
  393. else:
  394. torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr)
  395. @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
  396. def rmsprop(
  397. params: List[Tensor],
  398. grads: List[Tensor],
  399. square_avgs: List[Tensor],
  400. grad_avgs: List[Tensor],
  401. momentum_buffer_list: List[Tensor],
  402. state_steps: List[Tensor],
  403. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  404. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  405. foreach: Optional[bool] = None,
  406. maximize: bool = False,
  407. differentiable: bool = False,
  408. capturable: bool = False,
  409. has_complex: bool = False,
  410. *,
  411. lr: float,
  412. alpha: float,
  413. eps: float,
  414. weight_decay: float,
  415. momentum: float,
  416. centered: bool,
  417. ):
  418. r"""Functional API that performs rmsprop algorithm computation.
  419. See :class:`~torch.optim.RMSProp` for details.
  420. """
  421. # this check is slow during compilation, so we skip it
  422. # if it's strictly needed we can add this check back in dynamo
  423. if not torch._utils.is_compiling() and not all(
  424. isinstance(t, torch.Tensor) for t in state_steps
  425. ):
  426. raise RuntimeError(
  427. "API has changed, `state_steps` argument must contain a list of singleton tensors"
  428. )
  429. if foreach is None:
  430. _, foreach = _default_to_fused_or_foreach(
  431. params, differentiable, use_fused=False
  432. )
  433. if foreach and torch.jit.is_scripting():
  434. raise RuntimeError("torch.jit.script not supported with foreach optimizers")
  435. if foreach and not torch.jit.is_scripting():
  436. func = _multi_tensor_rmsprop
  437. else:
  438. func = _single_tensor_rmsprop
  439. func(
  440. params,
  441. grads,
  442. square_avgs,
  443. grad_avgs,
  444. momentum_buffer_list,
  445. state_steps,
  446. lr=lr,
  447. alpha=alpha,
  448. eps=eps,
  449. weight_decay=weight_decay,
  450. momentum=momentum,
  451. centered=centered,
  452. maximize=maximize,
  453. capturable=capturable,
  454. differentiable=differentiable,
  455. has_complex=has_complex,
  456. )