# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)

__all__ = ["ASGD", "asgd"]


class ASGD(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-2,
        lambd: float = 1e-4,
        alpha: float = 0.75,
        t0: float = 1e6,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            lambd=lambd,
            alpha=alpha,
            t0=t0,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0:
                    if not torch.is_tensor(p_state["step"]):
                        step_val = float(p_state["step"])
                        p_state["step"] = torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["eta"]):
                        p_state["eta"] = torch.tensor(
                            p_state["eta"], dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["mu"]):
                        p_state["mu"] = torch.tensor(
                            p_state["mu"], dtype=_get_scalar_dtype(), device=p.device
                        )

    def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("ASGD does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = torch.zeros(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
                    state["eta"] = (
                        torch.as_tensor(
                            group["lr"], device=p.device, dtype=_get_scalar_dtype()
                        )
                        .clone()
                        .detach()
                    )
                    state["mu"] = torch.ones(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
                    state["ax"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                mus.append(state["mu"])
                axs.append(state["ax"])
                etas.append(state["eta"])
                state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
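
        Example of the closure pattern (illustrative sketch; ``model``,
        ``inputs``, ``targets``, ``loss_fn``, and ``optimizer`` are
        placeholders, not defined in this module)::

            >>> # xdoctest: +SKIP
            >>> def closure():
            ...     optimizer.zero_grad()
            ...     loss = loss_fn(model(inputs), targets)
            ...     loss.backward()
            ...     return loss
            >>> optimizer.step(closure)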
  113. """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            mus: List[Tensor] = []
            axs: List[Tensor] = []
            etas: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_complex = self._init_group(
                group, params_with_grad, grads, mus, axs, etas, state_steps
            )

            asgd(
                params_with_grad,
                grads,
                axs,
                mus,
                etas,
                state_steps,
                lambd=group["lambd"],
                lr=group["lr"],
                t0=group["t0"],
                alpha=group["alpha"],
                weight_decay=group["weight_decay"],
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss


ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.

    It has been proposed in `Acceleration of stochastic approximation by
    averaging`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        lambd (float, optional): decay term (default: 1e-4)
        alpha (float, optional): power for eta update (default: 0.75)
        t0 (float, optional): point at which to start averaging (default: 1e6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        {_capturable_doc}
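
    Example (illustrative; ``model``, ``inputs``, ``targets``, and ``loss_fn``
    are placeholders, not defined in this module)::

        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.ASGD(model.parameters(), lr=1e-2)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(inputs), targets).backward()
        >>> optimizer.step()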

    .. _Acceleration of stochastic approximation by averaging:
        https://dl.acm.org/citation.cfm?id=131098

    """
def _single_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type
                == mu.device.type
                == eta.device.type
                == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), (
                f"If capturable=True, params, mus, etas, and state_steps must be "
                f"on supported devices: {capturable_supported_devices}."
            )

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            param = torch.view_as_real(param)
            ax = torch.view_as_real(ax)

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if capturable:
            param.mul_(1 - lambd * eta)
            param.addcmul_(grad, eta, value=-1)  # update parameter
        else:
            eta_value = _get_value(eta)
            param.mul_(1 - lambd * eta_value)  # decay term
            param.add_(grad, alpha=-eta_value)  # update parameter

        # averaging
        if capturable or mu.item() != 1:
            ax.add_(param.sub(ax).mul_(mu))
        else:
            ax.copy_(param)

        if capturable:
            eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
            mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
        else:
            step = _get_value(step_t)
            new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
            eta.copy_(new_eta)
            new_mu = torch.as_tensor(1 / max(1, step - t0))
            mu.copy_(new_mu)


def _multi_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == mu.device.type == eta.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, mu, eta, step in zip(params, mus, etas, state_steps)
        ), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, axs, mus, etas, state_steps]
    )
    for (device, _), (
        (
            grouped_params,
            grouped_grads,
            grouped_axs,
            grouped_mus,
            grouped_etas,
            grouped_state_steps,
        ),
        _,
    ) in grouped_tensors.items():
        if has_complex:
            _view_as_real(grouped_params, grouped_grads, grouped_axs)

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        # intermediate = grad + param * lambd
        intermediate: Union[Tuple[Tensor, ...], List[Tensor]]
        if weight_decay != 0:
            if maximize:
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
                intermediate = grouped_grads
            else:
                intermediate = torch._foreach_add(
                    grouped_grads, grouped_params, alpha=weight_decay
                )

            torch._foreach_add_(intermediate, grouped_params, alpha=lambd)
        else:
            intermediate = torch._foreach_add(
                grouped_grads, grouped_params, alpha=lambd
            )

        # update param
        # param * (1 - lambd * eta) - eta * grad
        # => param - param * lambd * eta - eta * grad
        # => param - eta * intermediate
        torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1)
        del intermediate

        # update grouped_axs
        # averaging: ax = ax + mu * (param - ax)
        # Note (mlazos): We can't use lerp here since it requires weight to be float64
        # and our grouping code requires dtypes to match for all tensors in a group (and it should, since
        # we use the mus in other places)
        # all dtypes need to match, so we could introduce a cast in a loop
        # but since this only adds one additional kernel launch, this looks like the cleaner
        # and faster solution
        intermediate = torch._foreach_sub(grouped_params, grouped_axs)
        torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
        del intermediate

        new_etas: Union[Tuple[Tensor, ...], List[Tensor]]
        new_mus: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # update grouped_mus
            new_mus = torch._foreach_sub(grouped_state_steps, t0)
            torch._foreach_maximum_(new_mus, 1.0)
            torch._foreach_reciprocal_(new_mus)
            torch._foreach_copy_(grouped_mus, new_mus)
            del new_mus

            # update eta = lr / ((1 + lambd * lr * step)^alpha)
            new_etas = torch._foreach_mul(grouped_state_steps, lambd)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_add_(new_etas, 1)
            torch._foreach_pow_(new_etas, alpha)
            torch._foreach_reciprocal_(new_etas)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_copy_(grouped_etas, new_etas)
        else:
            new_etas = [
                torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
                for step in grouped_state_steps
            ]
            new_mus = [
                torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device)
                for step in grouped_state_steps
            ]
            torch._foreach_copy_(grouped_etas, new_etas)
            torch._foreach_copy_(grouped_mus, new_mus)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
def asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
):
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.
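
    Example (illustrative sketch, not taken from the library documentation;
    the state tensors below are hypothetical and initialized the same way
    :class:`~torch.optim.ASGD` initializes its per-parameter state)::

        >>> # xdoctest: +SKIP
        >>> param = torch.randn(3)
        >>> grad = torch.randn(3)
        >>> ax = torch.zeros_like(param)  # running parameter average
        >>> mu = torch.ones(())  # averaging coefficient
        >>> eta = torch.tensor(1e-2)  # current per-step learning rate
        >>> step = torch.zeros(())  # step counter
        >>> asgd(
        ...     [param], [grad], [ax], [mu], [eta], [step],
        ...     lambd=1e-4, lr=1e-2, t0=1e6, alpha=0.75, weight_decay=0.0,
        ... )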
  376. """
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_asgd
    else:
        func = _single_tensor_asgd

    func(
        params,
        grads,
        axs,
        mus,
        etas,
        state_steps,
        lambd=lambd,
        lr=lr,
        t0=t0,
        alpha=alpha,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )