rprop.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. # mypy: allow-untyped-defs
  2. from typing import List, Optional, Tuple
  3. import torch
  4. from torch import Tensor
  5. from .optimizer import (
  6. _capturable_doc,
  7. _default_to_fused_or_foreach,
  8. _differentiable_doc,
  9. _disable_dynamo_if_unsupported,
  10. _foreach_doc,
  11. _get_capturable_supported_devices,
  12. _get_scalar_dtype,
  13. _maximize_doc,
  14. _use_grad_for_differentiable,
  15. _view_as_real,
  16. Optimizer,
  17. ParamsT,
  18. )
  19. __all__ = ["Rprop", "rprop"]
  20. class Rprop(Optimizer):
  21. def __init__(
  22. self,
  23. params: ParamsT,
  24. lr: float = 1e-2,
  25. etas: Tuple[float, float] = (0.5, 1.2),
  26. step_sizes: Tuple[float, float] = (1e-6, 50),
  27. *,
  28. capturable: bool = False,
  29. foreach: Optional[bool] = None,
  30. maximize: bool = False,
  31. differentiable: bool = False,
  32. ):
  33. if not 0.0 <= lr:
  34. raise ValueError(f"Invalid learning rate: {lr}")
  35. if not 0.0 < etas[0] < 1.0 < etas[1]:
  36. raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")
  37. defaults = dict(
  38. lr=lr,
  39. etas=etas,
  40. step_sizes=step_sizes,
  41. foreach=foreach,
  42. maximize=maximize,
  43. differentiable=differentiable,
  44. capturable=capturable,
  45. )
  46. super().__init__(params, defaults)
  47. def __setstate__(self, state):
  48. super().__setstate__(state)
  49. for group in self.param_groups:
  50. group.setdefault("foreach", None)
  51. group.setdefault("maximize", False)
  52. group.setdefault("differentiable", False)
  53. group.setdefault("capturable", False)
  54. for p in group["params"]:
  55. p_state = self.state.get(p, [])
  56. if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
  57. step_val = float(p_state["step"])
  58. p_state["step"] = (
  59. torch.tensor(
  60. step_val, dtype=_get_scalar_dtype(), device=p.device
  61. )
  62. if group["capturable"]
  63. else torch.tensor(step_val, dtype=_get_scalar_dtype())
  64. )
  65. def _init_group(self, group, params, grads, prevs, step_sizes, state_steps):
  66. has_complex = False
  67. for p in group["params"]:
  68. if p.grad is None:
  69. continue
  70. has_complex |= torch.is_complex(p)
  71. params.append(p)
  72. grad = p.grad
  73. if grad.is_sparse:
  74. raise RuntimeError("Rprop does not support sparse gradients")
  75. grads.append(grad)
  76. state = self.state[p]
  77. # State initialization
  78. if len(state) == 0:
  79. state["step"] = (
  80. torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
  81. if group["capturable"]
  82. else torch.zeros((), dtype=_get_scalar_dtype())
  83. )
  84. state["prev"] = torch.zeros_like(p, memory_format=torch.preserve_format)
  85. if p.dtype.is_complex:
  86. # Complex Number should be as if they are two independent real numbers.
  87. # Hence the step_size shouldn't be zero for imaginary part.
  88. state["step_size"] = torch.full_like(
  89. grad, complex(group["lr"], group["lr"])
  90. )
  91. else:
  92. state["step_size"] = torch.full_like(grad, group["lr"])
  93. prevs.append(state["prev"])
  94. step_sizes.append(state["step_size"])
  95. state_steps.append(state["step"])
  96. return has_complex
  97. @_use_grad_for_differentiable
  98. def step(self, closure=None):
  99. """Performs a single optimization step.
  100. Args:
  101. closure (Callable, optional): A closure that reevaluates the model
  102. and returns the loss.
  103. """
  104. self._cuda_graph_capture_health_check()
  105. loss = None
  106. if closure is not None:
  107. with torch.enable_grad():
  108. loss = closure()
  109. for group in self.param_groups:
  110. params: List[Tensor] = []
  111. grads: List[Tensor] = []
  112. prevs: List[Tensor] = []
  113. step_sizes: List[Tensor] = []
  114. state_steps: List[Tensor] = []
  115. etaminus, etaplus = group["etas"]
  116. step_size_min, step_size_max = group["step_sizes"]
  117. foreach = group["foreach"]
  118. maximize = group["maximize"]
  119. has_complex = self._init_group(
  120. group, params, grads, prevs, step_sizes, state_steps
  121. )
  122. rprop(
  123. params,
  124. grads,
  125. prevs,
  126. step_sizes,
  127. state_steps,
  128. step_size_min=step_size_min,
  129. step_size_max=step_size_max,
  130. etaminus=etaminus,
  131. etaplus=etaplus,
  132. foreach=foreach,
  133. maximize=maximize,
  134. differentiable=group["differentiable"],
  135. capturable=group["capturable"],
  136. has_complex=has_complex,
  137. )
  138. return loss
  139. Rprop.__doc__ = (
  140. r"""Implements the resilient backpropagation algorithm.
  141. .. math::
  142. \begin{aligned}
  143. &\rule{110mm}{0.4pt} \\
  144. &\textbf{input} : \theta_0 \in \mathbf{R}^d \text{ (params)},f(\theta)
  145. \text{ (objective)}, \\
  146. &\hspace{13mm} \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min}
  147. \text{ (step sizes)} \\
  148. &\textbf{initialize} : g^0_{prev} \leftarrow 0,
  149. \: \eta_0 \leftarrow \text{lr (learning rate)} \\
  150. &\rule{110mm}{0.4pt} \\
  151. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  152. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  153. &\hspace{5mm} \textbf{for} \text{ } i = 0, 1, \ldots, d-1 \: \mathbf{do} \\
  154. &\hspace{10mm} \textbf{if} \: g^i_{prev} g^i_t > 0 \\
  155. &\hspace{15mm} \eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+},
  156. \Gamma_{max}) \\
  157. &\hspace{10mm} \textbf{else if} \: g^i_{prev} g^i_t < 0 \\
  158. &\hspace{15mm} \eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-},
  159. \Gamma_{min}) \\
  160. &\hspace{15mm} g^i_t \leftarrow 0 \\
  161. &\hspace{10mm} \textbf{else} \: \\
  162. &\hspace{15mm} \eta^i_t \leftarrow \eta^i_{t-1} \\
  163. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1}- \eta_t \mathrm{sign}(g_t) \\
  164. &\hspace{5mm}g_{prev} \leftarrow g_t \\
  165. &\rule{110mm}{0.4pt} \\[-1.ex]
  166. &\bf{return} \: \theta_t \\[-1.ex]
  167. &\rule{110mm}{0.4pt} \\[-1.ex]
  168. \end{aligned}
  169. For further details regarding the algorithm we refer to the paper
  170. `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm
  171. <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.
  172. """
  173. + rf"""
  174. Args:
  175. params (iterable): iterable of parameters to optimize or dicts defining
  176. parameter groups
  177. lr (float, optional): learning rate (default: 1e-2)
  178. etas (Tuple[float, float], optional): pair of (etaminus, etaplus), that
  179. are multiplicative increase and decrease factors
  180. (default: (0.5, 1.2))
  181. step_sizes (Tuple[float, float], optional): a pair of minimal and
  182. maximal allowed step sizes (default: (1e-6, 50))
  183. {_foreach_doc}
  184. {_capturable_doc}
  185. {_maximize_doc}
  186. {_differentiable_doc}
  187. """
  188. )
  189. def _single_tensor_rprop(
  190. params: List[Tensor],
  191. grads: List[Tensor],
  192. prevs: List[Tensor],
  193. step_sizes: List[Tensor],
  194. state_steps: List[Tensor],
  195. *,
  196. step_size_min: float,
  197. step_size_max: float,
  198. etaminus: float,
  199. etaplus: float,
  200. maximize: bool,
  201. capturable: bool,
  202. differentiable: bool,
  203. has_complex: bool,
  204. ):
  205. for i, param in enumerate(params):
  206. grad = grads[i]
  207. grad = grad if not maximize else -grad
  208. prev = prevs[i]
  209. step_size = step_sizes[i]
  210. step = state_steps[i]
  211. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  212. if not torch._utils.is_compiling() and capturable:
  213. capturable_supported_devices = _get_capturable_supported_devices()
  214. assert (
  215. param.device.type == step.device.type
  216. and param.device.type in capturable_supported_devices
  217. ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  218. step += 1
  219. if torch.is_complex(param):
  220. grad = torch.view_as_real(grad)
  221. prev = torch.view_as_real(prev)
  222. param = torch.view_as_real(param)
  223. step_size = torch.view_as_real(step_size)
  224. if differentiable:
  225. sign = grad.mul(prev.clone()).sign()
  226. else:
  227. sign = grad.mul(prev).sign()
  228. if capturable:
  229. sign.copy_(torch.where(sign.gt(0), etaplus, sign))
  230. sign.copy_(torch.where(sign.lt(0), etaminus, sign))
  231. sign.copy_(torch.where(sign.eq(0), 1, sign))
  232. else:
  233. sign[sign.gt(0)] = etaplus
  234. sign[sign.lt(0)] = etaminus
  235. sign[sign.eq(0)] = 1
  236. # update stepsizes with step size updates
  237. step_size.mul_(sign).clamp_(step_size_min, step_size_max)
  238. # for dir<0, dfdx=0
  239. # for dir>=0 dfdx=dfdx
  240. grad = grad.clone(memory_format=torch.preserve_format)
  241. if capturable:
  242. grad.copy_(torch.where(sign.eq(etaminus), 0, grad))
  243. else:
  244. grad[sign.eq(etaminus)] = 0
  245. # update parameters
  246. param.addcmul_(grad.sign(), step_size, value=-1)
  247. prev.copy_(grad)
  248. def _multi_tensor_rprop(
  249. params: List[Tensor],
  250. grads: List[Tensor],
  251. prevs: List[Tensor],
  252. step_sizes: List[Tensor],
  253. state_steps: List[Tensor],
  254. *,
  255. step_size_min: float,
  256. step_size_max: float,
  257. etaminus: float,
  258. etaplus: float,
  259. maximize: bool,
  260. capturable: bool,
  261. differentiable: bool,
  262. has_complex: bool,
  263. ):
  264. if len(params) == 0:
  265. return
  266. assert not differentiable, "_foreach ops don't support autograd"
  267. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  268. if not torch._utils.is_compiling() and capturable:
  269. capturable_supported_devices = _get_capturable_supported_devices()
  270. assert all(
  271. p.device.type == step.device.type
  272. and p.device.type in capturable_supported_devices
  273. for p, step in zip(params, state_steps)
  274. ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  275. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  276. [params, grads, prevs, step_sizes, state_steps]
  277. )
  278. for (
  279. grouped_params,
  280. grouped_grads,
  281. grouped_prevs,
  282. grouped_step_sizes,
  283. grouped_state_steps,
  284. ), _ in grouped_tensors.values():
  285. # Update steps
  286. # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
  287. # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
  288. # wrapped it once now. The alpha is required to assure we go to the right overload.
  289. if grouped_state_steps[0].is_cpu:
  290. torch._foreach_add_(
  291. grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
  292. )
  293. else:
  294. torch._foreach_add_(grouped_state_steps, 1)
  295. # Handle complex params
  296. if has_complex:
  297. _view_as_real(
  298. grouped_params, grouped_grads, grouped_prevs, grouped_step_sizes
  299. )
  300. signs = torch._foreach_mul(grouped_grads, grouped_prevs)
  301. if maximize:
  302. torch._foreach_neg_(signs)
  303. # At the end of the step, grouped_prevs will contain the current grads, so we reuse
  304. # grouped_prevs memory instead of creating a new buffer, but, for clarity, we reassign
  305. # to keep referring to the buffer as grouped_grads.
  306. torch._foreach_copy_(grouped_prevs, grouped_grads)
  307. if maximize:
  308. torch._foreach_neg_(grouped_prevs)
  309. grouped_grads = grouped_prevs
  310. torch._foreach_sign_(signs)
  311. if capturable:
  312. for sign in signs:
  313. sign.copy_(torch.where(sign.gt(0), etaplus, sign))
  314. sign.copy_(torch.where(sign.lt(0), etaminus, sign))
  315. sign.copy_(torch.where(sign.eq(0), 1, sign))
  316. else:
  317. for sign in signs:
  318. sign[sign.gt(0)] = etaplus
  319. sign[sign.lt(0)] = etaminus
  320. sign[sign.eq(0)] = 1
  321. # update stepsizes with step size updates
  322. torch._foreach_mul_(grouped_step_sizes, signs)
  323. for step_size in grouped_step_sizes:
  324. step_size.clamp_(step_size_min, step_size_max)
  325. # for dir<0, dfdx=0
  326. # for dir>=0 dfdx=dfdx
  327. grouped_grads = list(grouped_grads)
  328. for i in range(len(grouped_grads)):
  329. grouped_grads[i].copy_(
  330. torch.where(signs[i].eq(etaminus), 0, grouped_grads[i])
  331. )
  332. # explicitly del signs as it's not used after here to save memory
  333. del signs
  334. # update parameters
  335. grad_signs = [grad.sign() for grad in grouped_grads]
  336. torch._foreach_addcmul_(
  337. grouped_params, grad_signs, grouped_step_sizes, value=-1
  338. )
  339. # Logically, you may expect grouped_prevs to get updated to grouped_grads, but that's
  340. # basically already happened since we've been using grouped_prevs' memory to store
  341. # updated grouped_grads!
  342. @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop)
  343. def rprop(
  344. params: List[Tensor],
  345. grads: List[Tensor],
  346. prevs: List[Tensor],
  347. step_sizes: List[Tensor],
  348. state_steps: List[Tensor],
  349. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  350. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  351. foreach: Optional[bool] = None,
  352. capturable: bool = False,
  353. maximize: bool = False,
  354. differentiable: bool = False,
  355. has_complex: bool = False,
  356. *,
  357. step_size_min: float,
  358. step_size_max: float,
  359. etaminus: float,
  360. etaplus: float,
  361. ):
  362. r"""Functional API that performs rprop algorithm computation.
  363. See :class:`~torch.optim.Rprop` for details.
  364. """
  365. # this check is slow during compilation, so we skip it
  366. # if it's strictly needed we can add this check back in dynamo
  367. if not torch._utils.is_compiling() and not all(
  368. isinstance(t, torch.Tensor) for t in state_steps
  369. ):
  370. raise RuntimeError(
  371. "API has changed, `state_steps` argument must contain a list of singleton tensors"
  372. )
  373. if foreach is None:
  374. _, foreach = _default_to_fused_or_foreach(
  375. params, differentiable, use_fused=False
  376. )
  377. if foreach and torch.jit.is_scripting():
  378. raise RuntimeError("torch.jit.script not supported with foreach optimizers")
  379. if foreach and not torch.jit.is_scripting():
  380. func = _multi_tensor_rprop
  381. else:
  382. func = _single_tensor_rprop
  383. func(
  384. params,
  385. grads,
  386. prevs,
  387. step_sizes,
  388. state_steps,
  389. step_size_min=step_size_min,
  390. step_size_max=step_size_max,
  391. etaminus=etaminus,
  392. etaplus=etaplus,
  393. capturable=capturable,
  394. maximize=maximize,
  395. differentiable=differentiable,
  396. has_complex=has_complex,
  397. )