# mypy: allow-untyped-defs
from typing import List, Optional

import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices

from .optimizer import (
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _foreach_doc,
    _fused_doc,
    _maximize_doc,
    _use_grad_for_differentiable,
    DeviceDict,
    Optimizer,
)

__all__ = ["SGD", "sgd"]


class SGD(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov=False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            maximize=maximize,
            foreach=foreach,
            differentiable=differentiable,
            fused=fused,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

        if fused:
            self._step_supports_amp_scaling = True

            fused_supported_devices = _get_fused_kernels_supported_devices()
            if not all(
                p.device.type in fused_supported_devices and torch.is_floating_point(p)
                for pg in self.param_groups
                for p in pg["params"]
            ):
                raise RuntimeError(
                    "`fused=True` requires all the params to be floating point Tensors of "
                    f"supported devices: {fused_supported_devices}."
                )
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("differentiable", False)
            group.setdefault("fused", False)

    def _init_group(self, group, params, grads, momentum_buffer_list):
        has_sparse_grad = False

        for p in group["params"]:
            if p.grad is not None:
                params.append(p)
                grads.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                if group["momentum"] != 0:
                    state = self.state[p]
                    momentum_buffer_list.append(state.get("momentum_buffer"))

        return has_sparse_grad

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params: List[Tensor] = []
            grads: List[Tensor] = []
            momentum_buffer_list: List[Optional[Tensor]] = []

            has_sparse_grad = self._init_group(
                group, params, grads, momentum_buffer_list
            )

            sgd(
                params,
                grads,
                momentum_buffer_list,
                weight_decay=group["weight_decay"],
                momentum=group["momentum"],
                lr=group["lr"],
                dampening=group["dampening"],
                nesterov=group["nesterov"],
                maximize=group["maximize"],
                has_sparse_grad=has_sparse_grad,
                foreach=group["foreach"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

            if group["momentum"] != 0:
                # update momentum_buffers in state
                for p, momentum_buffer in zip(params, momentum_buffer_list):
                    state = self.state[p]
                    state["momentum_buffer"] = momentum_buffer

        return loss


SGD.__doc__ = (
    r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
                \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
            &\hspace{10mm}\textbf{if} \: t > 1 \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
            &\hspace{10mm}\textbf{else} \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
            &\hspace{10mm}\textbf{else} \\[-1.ex]
            &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex]
            &\hspace{5mm}\textbf{else} \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_differentiable_doc}
        {_fused_doc}
    """
    + r"""

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.
    """
)
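
# Illustrative sketch (not part of this module): the ``.. note::`` above contrasts
# PyTorch's momentum convention with the Sutskever et al. form. Using made-up
# scalar values (mu, lr, g are hypothetical names), one plain-momentum step under
# each convention works out as follows; the two agree while lr stays constant and
# diverge only when lr changes mid-training.
#
#     >>> mu, lr, g = 0.9, 0.1, 1.0
#     >>> v_pt = mu * 0.0 + g           # PyTorch: lr is kept out of the buffer,
#     >>> step_pt = lr * v_pt           # so a later lr change rescales the whole step
#     >>> v_su = mu * 0.0 + lr * g      # Sutskever et al.: lr is folded into
#     >>> step_su = v_su                # the velocity itself
#     >>> step_pt == step_su
#     True
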
def sgd(
    params: List[Tensor],
    d_p_list: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if foreach is None and fused is None:
        # why must we be explicit about an if statement for torch.jit.is_scripting here?
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            fused, foreach = _default_to_fused_or_foreach(
                params, differentiable=False, use_fused=False
            )
        else:
            foreach = False
            fused = False
    if foreach is None:
        foreach = False
    if fused is None:
        fused = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    elif fused and not torch.jit.is_scripting():
        func = _fused_sgd
    else:
        func = _single_tensor_sgd

    func(
        params,
        d_p_list,
        momentum_buffer_list,
        weight_decay=weight_decay,
        momentum=momentum,
        lr=lr,
        dampening=dampening,
        nesterov=nesterov,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
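
# Minimal usage sketch for the functional API above (made-up tensors, shown as a
# skipped doctest): everything after the ``*`` must be passed by keyword, and the
# momentum buffer list is updated in place so callers can carry it across steps.
#
#     >>> # xdoctest: +SKIP
#     >>> params = [torch.zeros(3)]
#     >>> grads = [torch.ones(3)]
#     >>> bufs: List[Optional[Tensor]] = [None]
#     >>> sgd(params, grads, bufs, weight_decay=0.0, momentum=0.9, lr=0.1,
#     ...     dampening=0.0, nesterov=False, maximize=False)
#     >>> params[0]
#     tensor([-0.1000, -0.1000, -0.1000])
#     >>> bufs[0] is not None
#     True
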
def _single_tensor_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
):
    assert grad_scale is None and found_inf is None

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(grad).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)

            if nesterov:
                grad = grad.add(buf, alpha=momentum)
            else:
                grad = buf

        param.add_(grad, alpha=-lr)
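
# Hand check of the single-tensor loop above (made-up tensors, skipped doctest):
# with weight_decay=0, dampening=0 and a fresh buffer, the first step reduces to
# buf = g and p <- p - lr * g, matching the b_t recursion in the class docstring.
#
#     >>> # xdoctest: +SKIP
#     >>> p, g = torch.tensor([1.0]), torch.tensor([0.5])
#     >>> bufs: List[Optional[Tensor]] = [None]
#     >>> _single_tensor_sgd([p], [g], bufs, None, None, weight_decay=0.0,
#     ...                    momentum=0.9, lr=0.1, dampening=0.0, nesterov=False,
#     ...                    maximize=False, has_sparse_grad=False)
#     >>> p
#     tensor([0.9500])
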
def _multi_tensor_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
):
    assert grad_scale is None and found_inf is None

    if len(params) == 0:
        return

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=True  # type: ignore[list-item]
    )

    for (
        device_params,
        device_grads,
        device_momentum_buffer_list,
    ), indices in grouped_tensors.values():
        device_has_sparse_grad = has_sparse_grad and any(
            grad.is_sparse for grad in device_grads
        )

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        if momentum != 0:
            bufs = []

            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
                    bufs.append(device_momentum_buffer_list[i])

            if all_states_with_momentum_buffer:
                torch._foreach_mul_(bufs, momentum)
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
                for i in range(len(device_momentum_buffer_list)):
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[
                            indices[i]
                        ] = torch.clone(device_grads[i]).detach()
                    else:
                        buf = device_momentum_buffer_list[i]
                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

                    bufs.append(buf)

            if nesterov:
                torch._foreach_add_(device_grads, bufs, alpha=momentum)
            else:
                device_grads = bufs

        if not device_has_sparse_grad:
            # handle internal item() call if lr is a tensor
            if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
                grads_x_lr = torch._foreach_mul(device_grads, -lr)
                torch._foreach_add_(device_params, grads_x_lr)
            else:
                torch._foreach_add_(device_params, device_grads, alpha=-lr)
        else:
            # foreach APIs don't support sparse
            for i in range(len(device_params)):
                device_params[i].add_(device_grads[i], alpha=-lr)
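
# Sketch (made-up CPU tensors, skipped doctest): the foreach path above is a
# batched rewrite of _single_tensor_sgd. Grouping by (device, dtype) is what lets
# the torch._foreach_* calls operate on homogeneous tensor lists, and the two
# implementations are expected to agree numerically:
#
#     >>> # xdoctest: +SKIP
#     >>> p1, p2 = torch.ones(2), torch.ones(2)
#     >>> g = torch.full((2,), 0.5)
#     >>> kwargs = dict(weight_decay=0.0, momentum=0.0, lr=0.1, dampening=0.0,
#     ...               nesterov=False, maximize=False, has_sparse_grad=False)
#     >>> _single_tensor_sgd([p1], [g], [None], None, None, **kwargs)
#     >>> _multi_tensor_sgd([p2], [g.clone()], [None], None, None, **kwargs)
#     >>> torch.equal(p1, p2)
#     True
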
def _fused_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
) -> None:
    if not params:
        return
    if has_sparse_grad:
        raise RuntimeError("`_fused_sgd` does not support sparse gradients")

    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

    no_momentum_buffer = momentum == 0
    is_first_step = (
        all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
    )
    if is_first_step:
        for i, g in enumerate(grads):
            momentum_buffer_list[i] = torch.empty_like(g)

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=False  # type: ignore[list-item]
    )
    for (device, _), (
        (device_params, device_grads, device_momentum_buffer_list),
        _,
    ) in grouped_tensors.items():
        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            device_grad_scale = grad_scale_dict.setdefault(
                device, grad_scale.to(device)
            )
        if found_inf_dict is not None and found_inf is not None:
            device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device))
        torch._fused_sgd_(
            device_params,
            device_grads,
            [] if no_momentum_buffer else device_momentum_buffer_list,
            weight_decay=weight_decay,
            momentum=momentum,
            lr=lr,
            dampening=dampening,
            nesterov=nesterov,
            maximize=maximize,
            is_first_step=is_first_step,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )