optimizer.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import math
  4. import warnings
  5. from collections import defaultdict, OrderedDict
  6. from copy import deepcopy
  7. from itertools import chain
  8. from typing import (
  9. Any,
  10. Callable,
  11. cast,
  12. DefaultDict,
  13. Dict,
  14. Hashable,
  15. Iterable,
  16. List,
  17. Optional,
  18. overload,
  19. Set,
  20. Tuple,
  21. TypeVar,
  22. Union,
  23. )
  24. from typing_extensions import ParamSpec, Self, TypeAlias
  25. import torch
  26. import torch.utils.hooks as hooks
  27. from torch._utils import is_compiling
  28. from torch.utils._foreach_utils import (
  29. _get_foreach_kernels_supported_devices,
  30. _get_fused_kernels_supported_devices,
  31. _group_tensors_by_device_and_dtype,
  32. Indices,
  33. )
  34. from torch.utils.hooks import RemovableHandle
  35. Args: TypeAlias = Tuple[Any, ...]
  36. Kwargs: TypeAlias = Dict[str, Any]
  37. StateDict: TypeAlias = Dict[str, Any]
  38. TensorListList: TypeAlias = List[List[torch.Tensor]]
  39. DeviceDict = Dict[Optional[torch.device], torch.Tensor]
  40. GlobalOptimizerPreHook: TypeAlias = Callable[
  41. ["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]
  42. ]
  43. GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]
# Public API of this module.
__all__ = [
    "Optimizer",
    "register_optimizer_step_pre_hook",
    "register_optimizer_step_post_hook",
]

# Process-wide step hooks shared by every optimizer instance. They are keyed by
# the id of the RemovableHandle returned at registration. OrderedDict keeps
# registration order, which is the order the hooks are iterated when invoked.
_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict()

# Exact tensor types eligible for the foreach/fused fast paths. Membership is
# tested with ``type(p) in _foreach_supported_types`` (not isinstance), so
# subclasses of these types do not qualify.
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]
  52. class _RequiredParameter:
  53. """Singleton class representing a required parameter for an Optimizer."""
  54. def __repr__(self) -> str:
  55. return "<required parameter>"
  56. required = _RequiredParameter()
  57. def _use_grad_for_differentiable(func):
  58. def _use_grad(self, *args, **kwargs):
  59. import torch._dynamo
  60. prev_grad = torch.is_grad_enabled()
  61. try:
  62. # Note on graph break below:
  63. # we need to graph break to ensure that aot respects the no_grad annotation.
  64. # This is important for perf because without this, functionalization will generate an epilogue
  65. # which updates the mutated parameters of the optimizer which is *not* visible to inductor, as a result,
  66. # inductor will allocate for every parameter in the model, which is horrible.
  67. # With this, aot correctly sees that this is an inference graph, and functionalization will generate
  68. # an epilogue which is appended to the graph, which *is* visible to inductor, as a result, inductor sees that
  69. # step is in place and is able to avoid the extra allocation.
  70. # In the future, we will either 1) continue to graph break on backward, so this graph break does not matter
  71. # or 2) have a fully fused forward and backward graph, which will have no_grad by default, and we can remove this
  72. # graph break to allow the fully fused fwd-bwd-optimizer graph to be compiled.
  73. # see https://github.com/pytorch/pytorch/issues/104053
  74. torch.set_grad_enabled(self.defaults["differentiable"])
  75. torch._dynamo.graph_break()
  76. ret = func(self, *args, **kwargs)
  77. finally:
  78. torch._dynamo.graph_break()
  79. torch.set_grad_enabled(prev_grad)
  80. return ret
  81. functools.update_wrapper(_use_grad, func)
  82. return _use_grad
def _get_value(x):
    """Return ``x`` unchanged while compiling; otherwise unwrap a tensor via ``.item()``.

    Non-tensor inputs are returned unchanged in both modes.
    """
    # item is significantly faster than a cpu tensor in eager mode
    # NOTE(review): the ``not torch.jit.is_scripting()`` check is deliberately
    # first. Under TorchScript it short-circuits before ``is_compiling()``.
    # Keep this exact form when editing.
    if not torch.jit.is_scripting() and is_compiling():
        return x
    else:
        return x.item() if isinstance(x, torch.Tensor) else x
def _stack_if_compiling(x):
    """Return ``torch.stack(x)`` while compiling; otherwise return ``x`` unchanged.

    ``x`` is presumably a sequence of same-shaped tensors (required by
    ``torch.stack``); TODO confirm against callers.
    """
    # Same TorchScript-friendly guard ordering as in _get_value.
    if not torch.jit.is_scripting() and is_compiling():
        return torch.stack(x)
    else:
        return x
def _dispatch_sqrt(
    x: float,
):  # float annotation is needed because of torchscript type inference
    """Square root that accepts either a tensor or a Python number.

    Outside TorchScript, a tensor input uses ``Tensor.sqrt()``. Every other case,
    including all calls under TorchScript, goes through ``math.sqrt``.
    """
    if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
        return x.sqrt()
    else:
        return math.sqrt(x)
  101. def _disable_dynamo_if_unsupported(single_tensor_fn=None):
  102. # workaround for torchscript BC
  103. # it requires all called functions to be in the
  104. # global environment at the site at which the
  105. # maybe_fallback closure is created
  106. if single_tensor_fn:
  107. globals()[single_tensor_fn.__name__] = single_tensor_fn
  108. def wrapper(func):
  109. import inspect
  110. disabled_func = torch._disable_dynamo(func)
  111. ps = inspect.signature(func).parameters
  112. has_state_steps = True
  113. try:
  114. state_steps_ind = list(ps.keys()).index("state_steps")
  115. except ValueError:
  116. has_state_steps = False
  117. # Today, there are cases where we stack state steps
  118. # and pass them as the value arg of foreach ops.
  119. # Having state steps on cuda as the value arg is not supported in eager,
  120. # but this only occurs in the rare case that the user explicitly deletes
  121. # the capturable flag. If capturable=True, this is not a problem.
  122. @functools.wraps(func)
  123. def maybe_fallback(*args, **kwargs):
  124. if is_compiling() and (
  125. not kwargs.get("capturable", False)
  126. and has_state_steps
  127. and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
  128. or (
  129. "state_steps" in kwargs
  130. and kwargs["state_steps"]
  131. and kwargs["state_steps"][0].is_cuda
  132. )
  133. ):
  134. return disabled_func(*args, **kwargs)
  135. else:
  136. return func(*args, **kwargs)
  137. return maybe_fallback
  138. return wrapper
# For any optimizer with a faster implementation, we attempt to default to the
# fastest + stablest whenever possible. For foreach, the requirements are to have
# native params all on CUDA. For fused, there's currently the additional requirement
# that the tensors' dtypes must be floating point. Neither alternative supports
# torch.jit.script nor differentiable, so we fall back to the single tensor
# implementation in those cases.
# NOTE(review): the device checks below use the lists returned by
# _get_{fused,foreach}_kernels_supported_devices(), which may include more than
# CUDA; the "all on CUDA" wording above may be narrower than the code.
def _default_to_fused_or_foreach(
    params: List[torch.Tensor], differentiable: bool, use_fused: bool = False
) -> Tuple[bool, bool]:
    """Choose default ``(fused, foreach)`` flags for ``params``.

    Args:
        params: tensors to be optimized; ``None`` entries are skipped.
        differentiable: when True, both fast paths are disabled.
        use_fused: whether the fused path may be selected at all.

    Returns:
        ``(fused, foreach)``, with at most one of them True.

        - ``fused`` requires every non-None param to have exact type
          ``torch.Tensor``/``Parameter``, a device type from the fused-supported
          list, and a floating-point dtype.
        - ``foreach`` is considered only when ``fused`` is False. It requires the
          same exact type and a device type from the foreach-supported list.
    """
    # Early exit; keep this form, since it guards the TorchScript path.
    if torch.jit.is_scripting() or differentiable:
        return False, False

    fused_supported_devices = _get_fused_kernels_supported_devices()
    foreach_supported_devices = _get_foreach_kernels_supported_devices()
    # ``type(p) in ...`` is an exact-type test: tensor subclasses are excluded.
    fused = use_fused and all(
        p is None
        or (
            type(p) in _foreach_supported_types
            and p.device.type in fused_supported_devices
            and torch.is_floating_point(p)
        )
        for p in params
    )
    foreach = not fused and all(
        p is None
        or (
            type(p) in _foreach_supported_types
            and p.device.type in foreach_supported_devices
        )
        for p in params
    )
    return fused, foreach
  170. def _view_as_real(params, *state_and_grads):
  171. for i, p in enumerate(params):
  172. if torch.is_complex(p):
  173. params[i] = torch.view_as_real(params[i])
  174. for s in state_and_grads:
  175. s[i] = torch.view_as_real(s[i])
  176. def _get_scalar_dtype(is_fused=None):
  177. if is_fused:
  178. return torch.float32
  179. return (
  180. torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32
  181. )
def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]:
    r"""Return the device type list that supports capturable optimizer.

    The list always starts with ``"cuda"``. Outside TorchScript it also holds the
    name of the registered PrivateUse1 backend. ``"xla"`` is appended when
    ``supports_xla`` is True.
    """
    capturable_supported_devices = ["cuda"]
    # The private C binding is only queried outside TorchScript.
    if not torch.jit.is_scripting():
        capturable_supported_devices.append(torch._C._get_privateuse1_backend_name())
    if supports_xla:
        capturable_supported_devices.append("xla")
    return capturable_supported_devices
# Common doc strings among optimizers
# Reusable argument descriptions. Presumably they are interpolated into the
# docstrings of individual optimizer classes (not visible here). They are
# runtime strings; edit their text only deliberately.

# Description of the ``foreach`` constructor flag.
_foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer
is used. If unspecified by the user (so foreach is None), we will try to use
foreach over the for-loop implementation on CUDA, since it is usually
significantly more performant. Note that the foreach implementation uses
~ sizeof(params) more peak memory than the for-loop version due to the intermediates
being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer
parameters through the optimizer at a time or switch this flag to False (default: None)"""

# Description of the ``fused`` flag, plus a note on how fused/foreach defaults
# are resolved.
_fused_doc = r"""fused (bool, optional): whether the fused implementation is used.
Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
are supported. (default: None)
.. note:: The foreach and fused implementations are typically faster than the for-loop,
single-tensor implementation. Thus, if the user has not specified BOTH flags
(i.e., when foreach = fused = None), we will attempt defaulting to the foreach
implementation when the tensors are all on CUDA. For example, if the user specifies
True for fused but nothing for foreach, we will run the fused implementation. If
the user specifies False for foreach but nothing for fused (or False for fused but
nothing for foreach), we will run the for-loop implementation. If the user specifies
True for both foreach and fused, we will prioritize fused over foreach, as it is
typically faster. We attempt to use the fastest, so the hierarchy goes fused ->
foreach -> for-loop. HOWEVER, since the fused implementation is relatively new,
we want to give it sufficient bake-in time, so we default to foreach and NOT
fused when the user has not specified either flag."""

# Description of the ``capturable`` flag (CUDA graph capture safety).
_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
capture in a CUDA graph. Passing True can impair ungraphed performance,
so if you don't intend to graph capture this instance, leave it False
(default: False)"""

# Description of the ``differentiable`` flag.
_differentiable_doc = r"""differentiable (bool, optional): whether autograd should
occur through the optimizer step in training. Otherwise, the step()
function runs in a torch.no_grad() context. Setting to True can impair
performance, so leave it False if you don't intend to run autograd
through this instance (default: False)"""

# Description of the ``maximize`` flag.
_maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the
params, instead of minimizing (default: False)"""
  224. def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle:
  225. r"""Register a pre hook common to all optimizers. The hook should have the following
  226. signature::
  227. hook(optimizer, args, kwargs) -> None or modified args and kwargs
  228. Args:
  229. hook (Callable): A user defined hook which is registered on all optimizers.
  230. Returns:
  231. :class:`torch.utils.hooks.RemovableHandle`:
  232. a handle that can be used to remove the added hook by calling
  233. ``handle.remove()``
  234. """
  235. handle = hooks.RemovableHandle(_global_optimizer_pre_hooks)
  236. _global_optimizer_pre_hooks[handle.id] = hook
  237. return handle
  238. def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle:
  239. r"""Register a post hook common to all optimizers. The hook should have the following
  240. signature::
  241. hook(optimizer, args, kwargs) -> None
  242. Args:
  243. hook (Callable): A user defined hook which is registered on all optimizers.
  244. Returns:
  245. :class:`torch.utils.hooks.RemovableHandle`:
  246. a handle that can be used to remove the added hook by calling
  247. ``handle.remove()``
  248. """
  249. handle = hooks.RemovableHandle(_global_optimizer_post_hooks)
  250. _global_optimizer_post_hooks[handle.id] = hook
  251. return handle
  252. ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
  253. _P = ParamSpec("_P")
  254. R = TypeVar("R")
  255. T = TypeVar("T")
  256. class Optimizer:
  257. r"""Base class for all optimizers.
  258. .. warning::
  259. Parameters need to be specified as collections that have a deterministic
  260. ordering that is consistent between runs. Examples of objects that don't
  261. satisfy those properties are sets and iterators over values of dictionaries.
  262. Args:
  263. params (iterable): an iterable of :class:`torch.Tensor` s or
  264. :class:`dict` s. Specifies what Tensors should be optimized.
  265. defaults: (dict): a dict containing default values of optimization
  266. options (used when a parameter group doesn't specify them).
  267. """
  268. OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]] # type: ignore[misc]
  269. OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc]
  270. _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
  271. _optimizer_step_post_hooks: Dict[int, OptimizerPostHook]
  272. _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
  273. _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
  274. _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
  275. _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
  276. def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
  277. torch._C._log_api_usage_once("python.optimizer")
  278. self.defaults = defaults
  279. self._optimizer_step_pre_hooks = OrderedDict()
  280. self._optimizer_step_post_hooks = OrderedDict()
  281. self._optimizer_state_dict_pre_hooks = OrderedDict()
  282. self._optimizer_state_dict_post_hooks = OrderedDict()
  283. self._optimizer_load_state_dict_pre_hooks = OrderedDict()
  284. self._optimizer_load_state_dict_post_hooks = OrderedDict()
  285. self._patch_step_function()
  286. if isinstance(params, torch.Tensor):
  287. raise TypeError(
  288. "params argument given to the optimizer should be "
  289. "an iterable of Tensors or dicts, but got " + torch.typename(params)
  290. )
  291. self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
  292. self.param_groups: List[Dict[str, Any]] = []
  293. param_groups = list(params)
  294. if len(param_groups) == 0:
  295. raise ValueError("optimizer got an empty parameter list")
  296. if not isinstance(param_groups[0], dict):
  297. param_groups = [{"params": param_groups}]
  298. for param_group in param_groups:
  299. self.add_param_group(cast(dict, param_group))
  300. # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
  301. # which I don't think exists
  302. # https://github.com/pytorch/pytorch/issues/72948
  303. self._warned_capturable_if_run_uncaptured = True
  304. def __getstate__(self) -> Dict[str, Any]:
  305. return {
  306. "defaults": self.defaults,
  307. "state": self.state,
  308. "param_groups": self.param_groups,
  309. }
  310. def __setstate__(self, state: Dict[str, Any]) -> None:
  311. self.__dict__.update(state)
  312. if "_optimizer_step_pre_hooks" not in self.__dict__:
  313. self._optimizer_step_pre_hooks = OrderedDict()
  314. if "_optimizer_step_post_hooks" not in self.__dict__:
  315. self._optimizer_step_post_hooks = OrderedDict()
  316. if "_optimizer_state_dict_pre_hooks" not in self.__dict__:
  317. self._optimizer_state_dict_pre_hooks = OrderedDict()
  318. if "_optimizer_state_dict_post_hooks" not in self.__dict__:
  319. self._optimizer_state_dict_post_hooks = OrderedDict()
  320. if "_optimizer_load_state_dict_pre_hooks" not in self.__dict__:
  321. self._optimizer_load_state_dict_pre_hooks = OrderedDict()
  322. if "_optimizer_load_state_dict_post_hooks" not in self.__dict__:
  323. self._optimizer_load_state_dict_post_hooks = OrderedDict()
  324. self._patch_step_function() # To support multiprocessing pickle/unpickle
  325. self.defaults.setdefault("differentiable", False)
  326. def __repr__(self) -> str:
  327. format_string = self.__class__.__name__ + " ("
  328. for i, group in enumerate(self.param_groups):
  329. format_string += "\n"
  330. format_string += f"Parameter Group {i}\n"
  331. for key in sorted(group.keys()):
  332. if key != "params":
  333. format_string += f" {key}: {group[key]}\n"
  334. format_string += ")"
  335. return format_string
  336. # Currently needed by Adam and AdamW
  337. def _cuda_graph_capture_health_check(self) -> None:
  338. # Note [torch.compile x capturable]
  339. # If we are compiling, we try to take the capturable path automatically by
  340. # setting the flag to True during tracing. Due to this, we skip all the checks
  341. # normally required for determining whether we can use CUDA graphs and
  342. # shunt the responsibility to torch.inductor. This saves time during tracing
  343. # since the checks are slow without sacrificing UX since inductor will warn
  344. # later if CUDA graphs cannot be enabled, e.g.,
  345. # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390.
  346. # Thus, when compiling, inductor will determine if cudagraphs
  347. # can be enabled based on whether there is input mutation or CPU tensors.
  348. if (
  349. not is_compiling()
  350. and torch.backends.cuda.is_built()
  351. and torch.cuda.is_available()
  352. ):
  353. capturing = torch.cuda.is_current_stream_capturing()
  354. if capturing and not all(
  355. group["capturable"] for group in self.param_groups
  356. ):
  357. raise RuntimeError(
  358. "Attempting CUDA graph capture of step() for an instance of "
  359. + self.__class__.__name__
  360. + " but param_groups' capturable is False."
  361. )
  362. if (
  363. (not getattr(self, "_warned_capturable_if_run_uncaptured", False))
  364. and all(group["capturable"] for group in self.param_groups)
  365. and (not capturing)
  366. ):
  367. warnings.warn(
  368. "This instance was constructed with capturable=True or some of all the param_groups came with capturable=True, "
  369. "but step() is running without CUDA graph capture. If you never intend to graph-capture this "
  370. "instance, capturable=True can impair performance, and you should set capturable=False."
  371. )
  372. self._warned_capturable_if_run_uncaptured = True
  373. def _optimizer_step_code(self) -> None:
  374. """Entry point for `torch.profile.profiler`.
  375. When python tracing is enabled the profiler will hook into this
  376. function at the CPython level to inspect the optimizer's parameters and
  377. param groups. It is called it after `step()` since many optimizers
  378. lazily initialize state.
  379. This is a workaround due to lack of a proper step hook on the optimizer,
  380. and will be removed if it exists.
  381. """
  382. pass
  383. @staticmethod
  384. def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]:
  385. @functools.wraps(func)
  386. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R:
  387. self, *_ = args
  388. self = cast(Optimizer, self)
  389. profile_name = f"Optimizer.step#{self.__class__.__name__}.step"
  390. with torch.autograd.profiler.record_function(profile_name):
  391. # call optimizer step pre hooks
  392. for pre_hook in chain(
  393. _global_optimizer_pre_hooks.values(),
  394. self._optimizer_step_pre_hooks.values(),
  395. ):
  396. result = pre_hook(self, args, kwargs)
  397. if result is not None:
  398. if isinstance(result, tuple) and len(result) == 2:
  399. args, kwargs = result # type: ignore[assignment]
  400. else:
  401. raise RuntimeError(
  402. f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
  403. )
  404. out = func(*args, **kwargs)
  405. self._optimizer_step_code()
  406. # call optimizer step post hooks
  407. for post_hook in chain(
  408. self._optimizer_step_post_hooks.values(),
  409. _global_optimizer_post_hooks.values(),
  410. ):
  411. post_hook(self, args, kwargs)
  412. return out
  413. return wrapper
  414. @staticmethod
  415. def _group_tensors_by_device_and_dtype(
  416. tensorlistlist: TensorListList,
  417. with_indices: bool = False,
  418. ) -> Union[
  419. Dict[Tuple[None, None], Tuple[TensorListList, Indices]],
  420. Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]],
  421. ]:
  422. """Groups a list of lists of tensors by device and dtype.
  423. Skips this step if we are compiling since this will occur during inductor lowering.
  424. """
  425. if is_compiling():
  426. return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))}
  427. else:
  428. return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type]
  429. def _patch_step_function(self) -> None:
  430. self._zero_grad_profile_name = (
  431. f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad"
  432. )
  433. hooked = getattr(self.__class__.step, "hooked", None)
  434. if not hooked:
  435. self.__class__.step = self.profile_hook_step(self.__class__.step) # type: ignore[assignment]
  436. self.__class__.step.hooked = True # type: ignore[attr-defined]
  437. def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle:
  438. r"""Register an optimizer step pre hook which will be called before
  439. optimizer step. It should have the following signature::
  440. hook(optimizer, args, kwargs) -> None or modified args and kwargs
  441. The ``optimizer`` argument is the optimizer instance being used. If
  442. args and kwargs are modified by the pre-hook, then the transformed
  443. values are returned as a tuple containing the new_args and new_kwargs.
  444. Args:
  445. hook (Callable): The user defined hook to be registered.
  446. Returns:
  447. :class:`torch.utils.hooks.RemovableHandle`:
  448. a handle that can be used to remove the added hook by calling
  449. ``handle.remove()``
  450. """
  451. handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks)
  452. self._optimizer_step_pre_hooks[handle.id] = hook
  453. return handle
  454. def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle:
  455. r"""Register an optimizer step post hook which will be called after optimizer step.
  456. It should have the following signature::
  457. hook(optimizer, args, kwargs) -> None
  458. The ``optimizer`` argument is the optimizer instance being used.
  459. Args:
  460. hook (Callable): The user defined hook to be registered.
  461. Returns:
  462. :class:`torch.utils.hooks.RemovableHandle`:
  463. a handle that can be used to remove the added hook by calling
  464. ``handle.remove()``
  465. """
  466. handle = hooks.RemovableHandle(self._optimizer_step_post_hooks)
  467. self._optimizer_step_post_hooks[handle.id] = hook
  468. return handle
  469. def register_state_dict_pre_hook(
  470. self, hook: Callable[["Optimizer"], None], prepend: bool = False
  471. ) -> RemovableHandle:
  472. r"""Register a state dict pre-hook which will be called before
  473. :meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
  474. following signature::
  475. hook(optimizer) -> None
  476. The ``optimizer`` argument is the optimizer instance being used.
  477. The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
  478. The registered hook can be used to perform pre-processing before the ``state_dict``
  479. call is made.
  480. Args:
  481. hook (Callable): The user defined hook to be registered.
  482. prepend (bool): If True, the provided pre ``hook`` will be fired before
  483. all the already registered pre-hooks on ``state_dict``. Otherwise,
  484. the provided ``hook`` will be fired after all the already registered
  485. pre-hooks. (default: False)
  486. Returns:
  487. :class:`torch.utils.hooks.RemoveableHandle`:
  488. a handle that can be used to remove the added hook by calling
  489. ``handle.remove()``
  490. """
  491. handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks)
  492. self._optimizer_state_dict_pre_hooks[handle.id] = hook
  493. if prepend:
  494. self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False)
  495. return handle
  496. def register_state_dict_post_hook(
  497. self,
  498. hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
  499. prepend: bool = False,
  500. ) -> RemovableHandle:
  501. r"""Register a state dict post-hook which will be called after
  502. :meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
  503. following signature::
  504. hook(optimizer, state_dict) -> state_dict or None
  505. The hook will be called with arguments ``self`` and ``state_dict`` after generating
  506. a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally
  507. return a new one. The registered hook can be used to perform post-processing
  508. on the ``state_dict`` before it is returned.
  509. Args:
  510. hook (Callable): The user defined hook to be registered.
  511. prepend (bool): If True, the provided post ``hook`` will be fired before
  512. all the already registered post-hooks on ``state_dict``. Otherwise,
  513. the provided ``hook`` will be fired after all the already registered
  514. post-hooks. (default: False)
  515. Returns:
  516. :class:`torch.utils.hooks.RemoveableHandle`:
  517. a handle that can be used to remove the added hook by calling
  518. ``handle.remove()``
  519. """
  520. handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks)
  521. self._optimizer_state_dict_post_hooks[handle.id] = hook
  522. if prepend:
  523. self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False)
  524. return handle
  525. @torch._disable_dynamo
  526. def state_dict(self) -> StateDict:
  527. r"""Returns the state of the optimizer as a :class:`dict`.
  528. It contains two entries:
  529. * ``state``: a Dict holding current optimization state. Its content
  530. differs between optimizer classes, but some common characteristics
  531. hold. For example, state is saved per parameter, and the parameter
  532. itself is NOT saved. ``state`` is a Dictionary mapping parameter ids
  533. to a Dict with state corresponding to each parameter.
  534. * ``param_groups``: a List containing all parameter groups where each
  535. parameter group is a Dict. Each parameter group contains metadata
  536. specific to the optimizer, such as learning rate and weight decay,
  537. as well as a List of parameter IDs of the parameters in the group.
  538. NOTE: The parameter IDs may look like indices but they are just IDs
  539. associating state with param_group. When loading from a state_dict,
  540. the optimizer will zip the param_group ``params`` (int IDs) and the
  541. optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to
  542. match state WITHOUT additional verification.
  543. A returned state dict might look something like:
  544. .. code-block:: text
  545. {
  546. 'state': {
  547. 0: {'momentum_buffer': tensor(...), ...},
  548. 1: {'momentum_buffer': tensor(...), ...},
  549. 2: {'momentum_buffer': tensor(...), ...},
  550. 3: {'momentum_buffer': tensor(...), ...}
  551. },
  552. 'param_groups': [
  553. {
  554. 'lr': 0.01,
  555. 'weight_decay': 0,
  556. ...
  557. 'params': [0]
  558. },
  559. {
  560. 'lr': 0.001,
  561. 'weight_decay': 0.5,
  562. ...
  563. 'params': [1, 2, 3]
  564. }
  565. ]
  566. }
  567. """
  568. for pre_hook in self._optimizer_state_dict_pre_hooks.values():
  569. pre_hook(self)
  570. # Save order indices instead of Tensors
  571. param_mappings: Dict[int, int] = {}
  572. start_index = 0
  573. def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
  574. nonlocal start_index
  575. packed = {k: v for k, v in group.items() if k != "params"}
  576. param_mappings.update(
  577. {
  578. id(p): i
  579. for i, p in enumerate(group["params"], start_index)
  580. if id(p) not in param_mappings
  581. }
  582. )
  583. packed["params"] = [param_mappings[id(p)] for p in group["params"]]
  584. start_index += len(packed["params"])
  585. return packed
  586. param_groups = [pack_group(g) for g in self.param_groups]
  587. # Remap state to use order indices as keys
  588. packed_state = {
  589. (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
  590. for k, v in self.state.items()
  591. }
  592. state_dict = {
  593. "state": packed_state,
  594. "param_groups": param_groups,
  595. }
  596. for post_hook in self._optimizer_state_dict_post_hooks.values():
  597. hook_result = post_hook(self, state_dict)
  598. if hook_result is not None:
  599. state_dict = hook_result
  600. return state_dict
  @staticmethod
  def _process_value_according_to_param_policy(
      param: torch.Tensor,
      value: torch.Tensor,
      param_id: int,
      param_groups: List[Dict[Any, Any]],
      key: Hashable = None,
  ) -> torch.Tensor:
      r"""Convert one saved state tensor so it fits the live ``param``.

      Args:
          param: the live parameter this state entry belongs to.
          value: the saved state tensor to convert.
          param_id: the id of ``param`` inside the saved ``param_groups``;
              it selects the group whose ``fused`` / ``capturable`` flags
              apply.
          param_groups: the saved parameter groups, whose ``"params"``
              entries hold the ids. Must not be ``None``.
          key: the state key ``value`` was stored under.

      Returns:
          For ``key == "step"``: ``value`` moved to ``param``'s device as
          ``float32`` when the owning group is fused or capturable, and
          ``value`` itself, untouched, otherwise. This keeps ``state['step']``
          from being cast, see https://github.com/pytorch/pytorch/issues/74424
          and note [special device hosting for step].

          For any other key: ``value`` moved to ``param``'s device, and also
          cast to ``param``'s dtype when ``param`` is floating point.
      """
      assert param_groups is not None

      # The first group listing this id decides the flags. A missing flag,
      # or no matching group at all, counts as False.
      owner = next((pg for pg in param_groups if param_id in pg["params"]), None)
      fused = False if owner is None else owner.get("fused", False)
      capturable = False if owner is None else owner.get("capturable", False)

      if key == "step":
          if fused or capturable:
              return value.to(dtype=torch.float32, device=param.device)
          return value

      # Floating-point params are the only ones that dictate the dtype of
      # their state. For all other params only the device is matched.
      if param.is_floating_point():
          return value.to(dtype=param.dtype, device=param.device)
      return value.to(device=param.device)
  def register_load_state_dict_pre_hook(
      self,
      hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
      prepend: bool = False,
  ) -> RemovableHandle:
      r"""Register a load_state_dict pre-hook which will be called before
      :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
      following signature::

          hook(optimizer, state_dict) -> state_dict or None

      The ``optimizer`` argument is the optimizer instance being used and the
      ``state_dict`` argument is a shallow copy of the ``state_dict`` the user
      passed in to ``load_state_dict``. The hook may modify the state_dict inplace
      or optionally return a new one. If a state_dict is returned, it will be used
      to be loaded into the optimizer. When several pre-hooks are registered,
      each one receives the state_dict produced by the hooks that ran before it.

      The hook will be called with argument ``self`` and ``state_dict`` before
      calling ``load_state_dict`` on ``self``. The registered hook can be used to
      perform pre-processing before the ``load_state_dict`` call is made.

      Args:
          hook (Callable): The user defined hook to be registered.
          prepend (bool): If True, the provided pre ``hook`` will be fired before
              all the already registered pre-hooks on ``load_state_dict``. Otherwise,
              the provided ``hook`` will be fired after all the already registered
              pre-hooks. (default: False)

      Returns:
          :class:`torch.utils.hooks.RemovableHandle`:
              a handle that can be used to remove the added hook by calling
              ``handle.remove()``
      """
      handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks)
      self._optimizer_load_state_dict_pre_hooks[handle.id] = hook
      if prepend:
          # Hooks fire in mapping order, so moving the new entry to the front
          # makes it run before every previously registered pre-hook.
          self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
      return handle
  def register_load_state_dict_post_hook(
      self, hook: Callable[["Optimizer"], None], prepend: bool = False
  ) -> RemovableHandle:
      r"""Register a load_state_dict post-hook which will be called after
      :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
      following signature::

          hook(optimizer) -> None

      The ``optimizer`` argument is the optimizer instance being used.

      The hook will be called with argument ``self`` after calling
      ``load_state_dict`` on ``self``. The registered hook can be used to
      perform post-processing after ``load_state_dict`` has loaded the
      ``state_dict``. Any value returned by the hook is ignored.

      Args:
          hook (Callable): The user defined hook to be registered.
          prepend (bool): If True, the provided post ``hook`` will be fired before
              all the already registered post-hooks on ``load_state_dict``. Otherwise,
              the provided ``hook`` will be fired after all the already registered
              post-hooks. (default: False)

      Returns:
          :class:`torch.utils.hooks.RemovableHandle`:
              a handle that can be used to remove the added hook by calling
              ``handle.remove()``
      """
      handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
      self._optimizer_load_state_dict_post_hooks[handle.id] = hook
      if prepend:
          # Hooks fire in mapping order, so moving the new entry to the front
          # makes it run before every previously registered post-hook.
          self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
      return handle
  @torch._disable_dynamo
  def load_state_dict(self, state_dict: StateDict) -> None:
      r"""Loads the optimizer state.

      Args:
          state_dict (dict): optimizer state. Should be an object returned
              from a call to :meth:`state_dict`.

      Raises:
          ValueError: if ``state_dict`` has a different number of parameter
              groups than this optimizer, or contains a group whose number of
              parameters differs from the matching optimizer group.

      .. note::
          Saved parameter ids are matched to this optimizer's parameters
          purely by position: group order, then order within each group.
      """
      # shallow copy, to be consistent with module API
      state_dict = state_dict.copy()

      # Each pre-hook sees the (possibly replaced) dict left by the previous one.
      for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
          hook_result = pre_hook(self, state_dict)
          if hook_result is not None:
              state_dict = hook_result

      # Validate the state_dict
      groups = self.param_groups

      # Deepcopy as we write into saved_groups later to update state
      saved_groups = deepcopy(state_dict["param_groups"])

      if len(groups) != len(saved_groups):
          raise ValueError(
              "loaded state dict has a different number of " "parameter groups"
          )
      param_lens = (len(g["params"]) for g in groups)
      saved_lens = (len(g["params"]) for g in saved_groups)
      if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
          raise ValueError(
              "loaded state dict contains a parameter group "
              "that doesn't match the size of optimizer's group"
          )

      # Update the state: map every saved id to the live parameter at the
      # same position.
      id_map = dict(
          zip(
              chain.from_iterable(g["params"] for g in saved_groups),
              chain.from_iterable(g["params"] for g in groups),
          )
      )

      def _cast(param, value, param_id=None, param_groups=None, key=None):
          r"""Recursively rebuild ``value``, converting every tensor for ``param``.

          Tensors go through
          ``Optimizer._process_value_according_to_param_policy``. That may
          return the very same tensor object, e.g. for a non-fused,
          non-capturable ``step``, so this is not a guaranteed deep copy.
          Dicts and other iterable containers are rebuilt with their
          elements converted. Strings, bytes and non-iterable leaves are
          returned unchanged.
          """
          if isinstance(value, torch.Tensor):
              return Optimizer._process_value_according_to_param_policy(
                  param, value, param_id, param_groups, key
              )
          elif isinstance(value, dict):
              return {
                  k: _cast(
                      param, v, param_id=param_id, param_groups=param_groups, key=k
                  )
                  for k, v in value.items()
              }
          elif isinstance(value, (str, bytes)):
              # str/bytes are Iterable but atomic here. Rebuilding a str via
              # type(value)(<generator>) would yield the generator's repr
              # instead of the original text.
              return value
          elif isinstance(value, Iterable):
              return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)  # type: ignore[call-arg]
          else:
              return value

      # Copy state assigned to params (and cast tensors to appropriate types).
      # State that is not assigned to params is kept as is, without casting
      # (needed for backward compatibility).
      state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
      for k, v in state_dict["state"].items():
          if k in id_map:
              param = id_map[k]
              state[param] = _cast(
                  param, v, param_id=k, param_groups=state_dict["param_groups"]
              )
          else:
              state[k] = v

      # Update parameter groups, setting their 'params' value
      def update_group(
          group: Dict[str, Any], new_group: Dict[str, Any]
      ) -> Dict[str, Any]:
          new_group["params"] = group["params"]
          return new_group

      param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
      self.__setstate__({"state": state, "param_groups": param_groups})

      for post_hook in self._optimizer_load_state_dict_post_hooks.values():
          post_hook(self)
  766. @torch._disable_dynamo
  767. def zero_grad(self, set_to_none: bool = True) -> None:
  768. r"""Resets the gradients of all optimized :class:`torch.Tensor` s.
  769. Args:
  770. set_to_none (bool): instead of setting to zero, set the grads to None.
  771. This will in general have lower memory footprint, and can modestly improve performance.
  772. However, it changes certain behaviors. For example:
  773. 1. When the user tries to access a gradient and perform manual ops on it,
  774. a None attribute or a Tensor full of 0s will behave differently.
  775. 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
  776. are guaranteed to be None for params that did not receive a gradient.
  777. 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
  778. (in one case it does the step with a gradient of 0 and in the other it skips
  779. the step altogether).
  780. """
  781. foreach = self.defaults.get("foreach", False) or self.defaults.get(
  782. "fused", False
  783. )
  784. if not hasattr(self, "_zero_grad_profile_name"):
  785. self._patch_step_function()
  786. per_device_and_dtype_grads: Optional[
  787. DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]]
  788. ]
  789. if foreach:
  790. per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
  791. else:
  792. per_device_and_dtype_grads = None
  793. with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
  794. for group in self.param_groups:
  795. for p in group["params"]:
  796. if p.grad is not None:
  797. if set_to_none:
  798. p.grad = None
  799. else:
  800. if p.grad.grad_fn is not None:
  801. p.grad.detach_()
  802. else:
  803. p.grad.requires_grad_(False)
  804. if not foreach or p.grad.is_sparse:
  805. p.grad.zero_()
  806. else:
  807. assert per_device_and_dtype_grads is not None
  808. per_device_and_dtype_grads[p.grad.device][
  809. p.grad.dtype
  810. ].append(p.grad)
  811. if foreach:
  812. assert per_device_and_dtype_grads is not None
  813. for per_dtype_grads in per_device_and_dtype_grads.values():
  814. for grads in per_dtype_grads.values():
  815. torch._foreach_zero_(grads)
  @overload
  def step(self, closure: None = ...) -> None: ...

  @overload
  def step(self, closure: Callable[[], float]) -> float: ...

  def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
      r"""Perform a single optimization step, i.e. one parameter update.

      This base implementation is abstract and always raises
      ``NotImplementedError``. Concrete optimizers override it.

      Args:
          closure (Callable): A closure that reevaluates the model and
              returns the loss. Optional for most optimizers.

      .. note::
          Unless otherwise specified, this function should not modify the
          ``.grad`` field of the parameters.
      """
      raise NotImplementedError
  832. @torch._disable_dynamo
  833. def add_param_group(self, param_group: Dict[str, Any]) -> None:
  834. r"""Add a param group to the :class:`Optimizer` s `param_groups`.
  835. This can be useful when fine tuning a pre-trained network as frozen layers can be made
  836. trainable and added to the :class:`Optimizer` as training progresses.
  837. Args:
  838. param_group (dict): Specifies what Tensors should be optimized along with group
  839. specific optimization options.
  840. """
  841. if not isinstance(param_group, dict):
  842. raise TypeError(f"param_group must be a dict, but got {type(param_group)}")
  843. params = param_group["params"]
  844. if isinstance(params, torch.Tensor):
  845. param_group["params"] = [params]
  846. elif isinstance(params, set):
  847. raise TypeError(
  848. "optimizer parameters need to be organized in ordered collections, but "
  849. "the ordering of tensors in sets will change between runs. Please use a list instead."
  850. )
  851. else:
  852. param_group["params"] = list(params)
  853. for param in param_group["params"]:
  854. if not isinstance(param, torch.Tensor):
  855. raise TypeError(
  856. "optimizer can only optimize Tensors, "
  857. "but one of the params is " + torch.typename(param)
  858. )
  859. if not self.defaults.get("differentiable", None) and not (
  860. param.is_leaf or param.retains_grad
  861. ):
  862. raise ValueError("can't optimize a non-leaf Tensor")
  863. for name, default in self.defaults.items():
  864. if default is required and name not in param_group:
  865. raise ValueError(
  866. f"parameter group didn't specify a value of required optimization parameter {name}"
  867. )
  868. else:
  869. param_group.setdefault(name, default)
  870. params = param_group["params"]
  871. if len(params) != len(set(params)):
  872. warnings.warn(
  873. "optimizer contains a parameter group with duplicate parameters; "
  874. "in future, this will cause an error; "
  875. "see github.com/pytorch/pytorch/issues/40967 for more information",
  876. stacklevel=3,
  877. )
  878. param_set: Set[torch.Tensor] = set()
  879. for group in self.param_groups:
  880. param_set.update(set(group["params"]))
  881. if not param_set.isdisjoint(set(param_group["params"])):
  882. raise ValueError("some parameters appear in more than one parameter group")
  883. self.param_groups.append(param_group)