named_optimizer.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. # mypy: allow-untyped-defs
  2. import logging
  3. import warnings
  4. from copy import deepcopy
  5. from typing import Any, Callable, Collection, Dict, List, Mapping, Optional, Union, overload
  6. import torch
  7. import torch.nn as nn
  8. from torch import optim
  9. from torch.distributed._shard.sharded_tensor import ShardedTensor
  10. from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  11. __all__: List[str] = []
  12. logger = logging.getLogger(__name__)
  13. class _NamedOptimizer(optim.Optimizer):
  14. """
  15. ``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key.
  16. We replace the original key (number) in an optim to the
  17. fully qualified name (FQN) string. User can initialize the optim as they
  18. initialize a PyTorch optim, the only difference is that they also need to
  19. pass in the FQN of each parameters.
  20. Args:
  21. named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]):
  22. Mapping from FQN to parameter.
  23. optimizer_class (optim.Optimizer):
  24. The class of optimizer to instantiate.
  25. param_groups (Collection[Mapping[str, Any]]):
  26. `param_groups` to pass to optimizer if specified.
  27. The key of the inner map needs to be FQNs.
  28. Default: None
  29. module (nn.Module): the module whose parameters to updated
  30. by the optimizer.
  31. args: arguments to pass to the optimizer constructor.
  32. kwargs: arguments to pass to the optimizer constructor.
  33. Example::
  34. >>> # xdoctest: +SKIP("distributed")
  35. >>> from torch import optim
  36. >>> from torch.distributed.optim import _NamedOptimizer
  37. >>>
  38. >>> # Define the named optimizer.
  39. >>> m = Model(...)
  40. >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD)
  41. >>> # Forward pass + backward pass.
  42. >>> named_optim.step()
  43. >>> ...
  44. >>> # Call state_dict for the named optimizer returns a FQN state_dict.
  45. >>> named_optim.state_dict()
  46. Warning: This API is still in development and subject to change.
  47. TODO: Add tutorial for _NamedOptimizer.
  48. TODO: Add documentation in the docstring for the public attributes
  49. like self.param_groups and self.named_parameters.
  50. """
  51. def __init__(
  52. self,
  53. named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
  54. optimizer_class: optim.Optimizer,
  55. param_groups: Optional[Collection[Mapping[str, Any]]] = None,
  56. module: Optional[nn.Module] = None,
  57. *args,
  58. **kwargs,
  59. ) -> None:
  60. torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer")
  61. self.param_groups: Collection[Mapping[str, Any]] = param_groups # type: ignore[assignment]
  62. self._param_groups_check()
  63. self.named_parameters = dict(named_parameters)
  64. params_for_optimizer = (
  65. self.named_parameters.values() if param_groups is None else param_groups
  66. )
  67. self._optimizer = optimizer_class( # type: ignore[operator]
  68. params_for_optimizer,
  69. *args,
  70. **kwargs,
  71. )
  72. self.module = module
  73. if param_groups is None:
  74. self.ordered_param_keys = list(self.named_parameters.keys())
  75. else:
  76. warnings.warn(
  77. "Since we pass in param_groups, we will use param_groups to "
  78. "initialize the optimizer, not all parameters of the module."
  79. )
  80. param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type]
  81. ordered_param_keys = []
  82. for group in param_groups:
  83. for param in group["params"]:
  84. if param not in param_to_key:
  85. raise ValueError(
  86. f"Expect param name {param} found in param group but is missing."
  87. )
  88. ordered_param_keys.append(param_to_key[param])
  89. self.ordered_param_keys = ordered_param_keys
  90. # Update param_groups from optimizer.
  91. self.param_groups = self._optimizer.param_groups
  92. def _param_groups_check(self):
  93. if self.param_groups is not None:
  94. for param_group in self.param_groups:
  95. assert isinstance(param_group, dict), "param group must be a dict"
  96. assert "params" in param_group, "param group must contain key params"
  97. params = param_group["params"]
  98. if isinstance(params, torch.Tensor):
  99. params = [params]
  100. params = list(params)
  101. for param in params:
  102. if not isinstance(param, torch.Tensor):
  103. raise TypeError(
  104. "optimizer can only optimize Tensors, "
  105. "but one of the params is " + torch.typename(param)
  106. )
  107. param_group["params"] = params
  108. def state_dict(self) -> Dict[str, Any]:
  109. """
  110. Return the ``state_dict`` of the optimizer.
  111. Instead of using number to index
  112. parameters, we will use module fully qualified name (FQN) as the key.
  113. """
  114. state_dict = self._optimizer.state_dict()
  115. param_groups = state_dict["param_groups"]
  116. ret_state = {
  117. self.ordered_param_keys[st_key]: state_val
  118. for st_key, state_val in state_dict["state"].items()
  119. }
  120. ret_groups = []
  121. for group in param_groups:
  122. param_keys = []
  123. for param in group["params"]:
  124. param_keys.append(self.ordered_param_keys[param])
  125. ret_group = {"params": sorted(param_keys)}
  126. for k, v in group.items():
  127. if k != "params":
  128. ret_group[k] = deepcopy(v)
  129. ret_groups.append(ret_group)
  130. return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})
  131. @overload
  132. def step(self, closure: None = ...) -> None:
  133. ...
  134. @overload
  135. def step(self, closure: Callable[[], float]) -> float:
  136. ...
  137. def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
  138. """
  139. Perform a single optimization step.
  140. This will call :meth:`torch.optim.Optimizer.step` on the wrapped
  141. optimizer.
  142. """
  143. return self._optimizer.step(closure=closure)
  144. @property
  145. def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override]
  146. return self._optimizer.state
  147. def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
  148. """
  149. Define the default behavior to load a state_dict for ``_NamedOptimizer``.
  150. Sample Code
  151. ```
  152. my_model = MyModule()
  153. optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad)
  154. ...
  155. optim_state_dict = optimizer.state_dict()
  156. ...
  157. ...
  158. optimizer.load_state_dict(optim_state_dict)
  159. ...
  160. ```
  161. Args:
  162. state_dict (Dict[str, Any]) : A ``state_dict`` to load into the optimizer.
  163. Note that this state dict update is performed in place.
  164. .. note:: PyTorch is using lazy init to initialize the optim states.
  165. So it is possible that there is no optim state when user call
  166. ``load_state_dict`` and for ``_NamedOptimizer`` we make it stricter
  167. that users can only call ``load_state_dict`` after the state is initialized.
  168. By doing this, we can validate the optim ``state_dict`` to be loaded.
  169. """
  170. new_state_dict = self._optimizer.state_dict()
  171. state_dict = self._pre_load_state_dict(state_dict)
  172. state = state_dict["state"]
  173. new_state = new_state_dict["state"]
  174. if len(new_state) == 0:
  175. raise ValueError(
  176. "Expects the optim to be initialized before load but found not initialized."
  177. )
  178. for idx, param_key in enumerate(self.ordered_param_keys):
  179. # When the conditional training is performed, not all parameters are updated in the optim.
  180. if param_key not in state.keys():
  181. continue
  182. if len(state[param_key]) != len(new_state[idx]):
  183. raise ValueError(
  184. f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}"
  185. )
  186. # Iterate through all optimizer states.
  187. for state_key, state_val in new_state[idx].items():
  188. if state_key not in state[param_key]:
  189. raise ValueError(
  190. f"Expects state {state_key} for parameter {param_key} but not found."
  191. )
  192. src_state_val = state[param_key][state_key]
  193. if isinstance(state_val, ShardedTensor):
  194. assert isinstance(src_state_val, ShardedTensor)
  195. num_shards = len(state_val.local_shards())
  196. num_new_shards = len(src_state_val.local_shards())
  197. if num_shards != num_new_shards:
  198. raise ValueError(
  199. f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}"
  200. )
  201. for shard, src_shard in zip(
  202. state_val.local_shards(), src_state_val.local_shards()
  203. ):
  204. shard.tensor.detach().copy_(src_shard.tensor)
  205. elif isinstance(state_val, torch.Tensor):
  206. assert isinstance(src_state_val, torch.Tensor)
  207. state_val.detach().copy_(src_state_val)
  208. else:
  209. new_state[idx][state_key] = deepcopy(src_state_val)
  210. # Load param_groups of state_dict
  211. src_param_groups = state_dict["param_groups"]
  212. new_param_groups = new_state_dict["param_groups"]
  213. src_group_map = {}
  214. for group in src_param_groups:
  215. param_keys = list(group["params"])
  216. src_group_map[_gen_param_group_key(param_keys)] = group
  217. new_group_map = {}
  218. for new_group in new_param_groups:
  219. param_keys = []
  220. for param_key in new_group["params"]:
  221. param_keys.append(self.ordered_param_keys[param_key]) # type: ignore[call-overload]
  222. new_group_map[_gen_param_group_key(param_keys)] = new_group
  223. for group_key, new_group in new_group_map.items():
  224. # When not all parameters are used in training or receive gradient, aka., not all parameters
  225. # would be in the param_group. Thus we skip the group_key here.
  226. if group_key not in src_group_map:
  227. continue
  228. src_group = src_group_map[group_key]
  229. if len(src_group) != len(new_group):
  230. raise ValueError(
  231. f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}."
  232. )
  233. for k in src_group:
  234. if k not in new_group:
  235. raise ValueError(
  236. f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing."
  237. )
  238. if k != "params":
  239. new_group[k] = deepcopy(src_group[k])
  240. self._optimizer.load_state_dict(new_state_dict)
  241. def add_param_group(self, param_group: Mapping[str, Any]) -> None:
  242. """
  243. Add a param group to the :class:`_NamedOptimizer` s `param_groups`.
  244. Warning: This API is still in development and subject to change.
  245. """
  246. assert isinstance(param_group, dict), "param group must be a dict"
  247. params = param_group["params"]
  248. if isinstance(params, torch.Tensor):
  249. param_group["params"] = [params]
  250. else:
  251. param_group["params"] = list(params)
  252. param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type]
  253. for param in param_group["params"]:
  254. if param not in param_to_key:
  255. raise ValueError("some parameters are not in the module")
  256. self.ordered_param_keys.append(param_to_key[param])
  257. self._optimizer.add_param_group(param_group)
  258. # Update param_groups from optimizer.
  259. self.param_groups = self._optimizer.param_groups
  260. def init_state(self) -> None:
  261. """
  262. Run a dummy optimizer step, which allows to initialize optimizer state because we do lazy init for most optimizers.
  263. This allows doing in-place loading of optimizer state from a checkpoint.
  264. """
  265. for param in self.named_parameters.values():
  266. if param.requires_grad:
  267. t = torch.zeros_like(param)
  268. param.grad = torch.autograd.Variable(t)
  269. # Calling ``step`` will load the initial state for optimizer states.
  270. self.step(closure=None)
  271. def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]:
  272. # TODO(chienchin): This API should be FSDP agnostic and should support
  273. # general user hooks.
  274. if isinstance(self.module, FSDP):
  275. return FSDP.optim_state_dict_to_load(
  276. self.module, self._optimizer, state_dict, is_named_optimizer=True
  277. )
  278. return state_dict
  279. def _post_state_dict(self, state_dict) -> Dict[str, Any]:
  280. # TODO(chienchin): This API should be FSDP agnostic and should support
  281. # general user hooks.
  282. if isinstance(self.module, FSDP):
  283. FSDP.optim_state_dict(self.module, self._optimizer, state_dict)
  284. return state_dict
  285. def _gen_param_group_key(param_keys: List[str]) -> str:
  286. """Concatenate all param keys as a unique indentifier for one param group."""
  287. return "/".join(sorted(param_keys))