# mypy: ignore-errors
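"""Dynamo tracing support for ``torch.optim`` optimizers.

This module special-cases ``Optimizer._init_group`` so that Dynamo does not
have to trace the (very slow) optimizer state initialization; see
``OptimizerVariable`` below.
"""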
import weakref
from typing import Dict, List, TYPE_CHECKING

import torch
from torch.utils._pytree import tree_map_only

from ..guards import GuardBuilder, install_guard
from ..source import (
    AttrSource,
    ConstDictKeySource,
    GetItemSource,
    GlobalWeakRefSource,
    GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from .base import VariableTracker


class ArgMappingException(Exception):
    pass


class GuardInstallException(Exception):
    pass

class OptimizerVariable(UserDefinedObjectVariable):
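    """VariableTracker for ``torch.optim.Optimizer`` instances.

    Instead of tracing ``_init_group``, ``call_method`` runs it eagerly,
    manually installs the guards that tracing would have produced, marks the
    parameter and state tensors as static addresses (for cudagraphs), and
    mirrors the lists populated by the eager call back into the traced
    arguments.

    Illustrative sketch (an assumption about typical usage; Dynamo builds this
    variable automatically when a compiled frame touches an optimizer)::

        opt = torch.optim.Adam(model.parameters())

        @torch.compile
        def train_step():
            opt.step()
            opt.zero_grad()
    """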
    _nonvar_fields = {
        "grad_to_source",
        "tensor_to_source",
        "static_tensor_names",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        grad_to_source=None,
        static_tensor_names=None,
        tensor_to_source=None,
        **kwargs,
    ):
        super().__init__(value, **kwargs)
        self.grad_to_source = grad_to_source or {}
        self.tensor_to_source = tensor_to_source or {}
        self.static_tensor_names = static_tensor_names or set()
    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            try:
                self.graph_break_if_pending_mutation(tx)
                self.move_step_if_cpu()
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                ret_val = self.value._init_group(*py_args, **py_kwargs)
                self.map_sources_and_install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                # stash a weak_ptr to optimizer to invalidate code
                # if the optimizer object dies
                mangled_name = f"__optimizer_{id(self.value)}"
                tx.store_global_weakref_by_id(mangled_name, self.value)
                self.create_finalizer(tx)

                # This is currently safe only because the only actual `ret_val`s returned
                # by the `_init_group` of existing optimizers are properties that are invariant
                # to the input tensors (e.g. dtype, layout). Changing these would trigger a
                # recompilation and hence never result in the wrong specialization of `ret_val`.
                return ConstantVariable.create(ret_val)
            except (ArgMappingException, GuardInstallException) as _:
                # trace normally if we can't map args or install guards correctly
                pass

        return super().call_method(tx, name, args, kwargs)
    def var_getattr(self, tx, name):
        # Note: this allows us to intercept the call in call_method;
        # in the typical case, we return a UserMethodVariable
        # which will directly inline
        if name in ("_init_group", "step"):
            return GetAttrVariable(self, name, source=AttrSource(self.source, name))

        if name == "param_groups":
            from ..decorators import mark_static_address

            for group in self.value.param_groups:
                for p in group["params"]:
                    mark_static_address(p)

            self._set_capturable(tx)

        return super().var_getattr(tx, name)
    def graph_break_if_pending_mutation(self, tx):
        # If there are pending mutations on a parameter (due to using closure)
        # then we need to graph break to allow the python version of the parameter
        # to update, so that running _init_group will initialize the states with
        # the correct values
        for g in self.value.param_groups:
            for p in g["params"]:
                side_effects = tx.output.side_effects
                variable = side_effects.id_to_variable.get(id(p), None)
                if variable and side_effects.has_pending_mutation(variable):
                    from ..exc import Unsupported

                    raise Unsupported("Pending mutation on parameter")
    def _set_capturable(self, tx):
        from . import LazyVariableTracker
        from .builder import VariableBuilder

        # We only set capturable if params are on cuda
        # and the state is not initialized
        def safe_to_set_capturable(group):
            all_uninitialized = True
            all_cuda = True

            for p in group.get("params", list()):
                all_cuda &= p.is_cuda
                all_uninitialized &= p not in self.value.state

            return "capturable" in group and all_uninitialized and all_cuda

        # track indices to not set so we don't need to realize
        # the whole state in the variable tracker;
        # we handle guarding the state specially
        for ind, group in enumerate(self.value.param_groups):
            if safe_to_set_capturable(group):
                group["capturable"] = True

        param_groups_vt = LazyVariableTracker.realize_all(
            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
                self.value.param_groups
            )
        )
        for ind, param_group_vt in enumerate(param_groups_vt.items):
            key = ConstDictVariable._HashableTracker(
                ConstantVariable.create("capturable")
            )
            param_group_vt.items[key] = ConstantVariable.create(True)
    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args"""
        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs
    # If users load an old state dictionary,
    # it's possible that step could be on the cpu;
    # if this is the case, move it to the GPU
    # corresponding to the parameter.
    # In most cases this is a no-op because the state is empty.
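    # Illustrative example (hypothetical checkpoint): after
    # ``opt.load_state_dict(torch.load("ckpt.pt"))`` a CUDA parameter can end
    # up with ``state["step"]`` on the CPU; the loop below moves it to
    # ``p.device``.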
    def move_step_if_cpu(self):
        for p, state in self.value.state.items():
            if "step" in state and state["step"].is_cpu:
                state["step"] = state["step"].to(p.device)
    def map_sources_and_install_guards(self, tx):
        from ..decorators import mark_static_address
        from .builder import VariableBuilder
        from .lazy import LazyVariableTracker

        self.grad_to_source = {}
        self.tensor_to_source = {}

        # Tracing the _init_group is expensive. But we still have to insert the
        # necessary guards for _init_group. So, we manually handle insertion of
        # guards. We also want to mark all the tensors inside the state dict to
        # be static address.

        # Mark all the tensors in the state dict to be static address. This has
        # to be done first because the variable builder relies on the static
        # address annotation.
        def mark_static(x):
            mark_static_address(x)

        tree_map_only(torch.Tensor, mark_static, self.value.state)

        # Recursively realize the variable trackers for optim.state and
        # optim.param_groups, which recursively install the necessary guards.
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
                self.value.param_groups
            )
        )

        state_vt = VariableBuilder(tx, AttrSource(self.source, "state"))(
            self.value.state
        )

        # We need to realize the top level state dict to populate
        # the guard locals
        state_vt.realize()

        # Populate self.grad_to_source and self.tensor_to_source so that we can
        # manually update_list_args
        for g_ind, (group, group_vt) in enumerate(
            zip(self.value.param_groups, param_groups_vt.items)
        ):
            # we assume here that all params within a param group
            # are initialized similarly
            if len(group["params"]) > 0:
                for param in group["params"]:
                    if param.grad is not None:
                        key_index = None
                        for i, k in enumerate(self.value.state.keys()):
                            if k is param:
                                key_index = i
                                break
                        # key_index may legitimately be 0, so compare against None
                        if key_index is not None:
                            state_source = AttrSource(self.source, "state")
                            LazyVariableTracker.realize_all(
                                VariableBuilder(
                                    tx,
                                    GetItemSource(
                                        state_source,
                                        ConstDictKeySource(state_source, key_index),
                                    ),
                                )(self.value.state[param])
                            )
                            break
            group_source = group_vt.source
            params_vt = group_vt.getitem_const(ConstantVariable.create("params"))
            for p_ind, (p, p_vt) in enumerate(
                zip(group["params"], params_vt.unpack_var_sequence(tx))
            ):
                param_source = p_vt.source
                self.tensor_to_source[p] = param_source
                grad_source = GradSource(
                    param_source,
                    "grad",
                )

                if p.grad is not None:
                    self.grad_to_source[p.grad] = grad_source
                else:
                    install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))

        # We have to again iterate over the state dict to collect the
        # tensor_to_source dict. This is used for the finalizer.
        state_source = AttrSource(self.source, "state")
        for idx, (p, value) in enumerate(self.value.state.items()):
            p_state_source = GetItemSource(
                state_source, ConstDictKeySource(state_source, idx)
            )
            for k, v in value.items():
                if (
                    isinstance(v, torch.Tensor)
                    and v not in self.grad_to_source
                    and v not in self.tensor_to_source
                ):
                    self.tensor_to_source[v] = GetItemSource(p_state_source, k)
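    # Illustratively, the mappings built above look roughly like the following
    # (assuming the optimizer is the local ``opt`` and an Adam-style state
    # dict; the exact source strings are an assumption, not guaranteed):
    #   tensor_to_source[param]    ~ L['opt'].param_groups[0]['params'][0]
    #   grad_to_source[param.grad] ~ L['opt'].param_groups[0]['params'][0].grad
    #   tensor_to_source[exp_avg]  ~ L['opt'].state[<param key>]['exp_avg']
    # wrap_tensor below consults these sources so state tensors get properly
    # guarded sources instead of being traced from scratch.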
    def wrap_tensor(self, tx, tensor_value):
        """Wrap state tensor in a TensorVariable"""
        from ..decorators import mark_static_address
        from .builder import VariableBuilder

        # If we have a source for a tensor already use it,
        # if we have not seen a tensor before, stash and use a
        # global weak ref source, since it must be an optimizer tensor
        # that we have missed

        if tensor_value in self.tensor_to_source:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)
            builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))
        elif tensor_value in self.grad_to_source:
            builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
        else:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)

            global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
            builder = VariableBuilder(tx, GlobalWeakRefSource(global_name))
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))

        result = builder(tensor_value)
        return result
    def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
        """Update the args and kwargs to the traced optimizer call"""
        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable):
                assert isinstance(
                    py_arg, list
                ), "py_arg should be a list in optimizer variable"
                for i, val in enumerate(py_arg):
                    tx.output.side_effects.mutation(arg)
                    if isinstance(val, torch.Tensor):
                        arg.items.append(self.wrap_tensor(tx, val))
                    else:
                        from .builder import SourcelessBuilder, VariableBuilder

                        if arg.source:
                            arg.items.append(
                                VariableBuilder(tx, GetItemSource(arg.source, i))(val)
                            )
                        else:
                            arg.items.append(SourcelessBuilder.create(tx, val))
    def create_finalizer(self, tx):
        names_to_delete = self.static_tensor_names
        value = self.value
        tc = tx.output.tracing_context
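        # When the optimizer object is garbage collected, drop the
        # cudagraph-static tensors that were registered on the GraphModule (and
        # the flat parameter list) so they no longer keep those tensors alive;
        # the weakref global stashed in call_method separately invalidates the
        # compiled code.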
        def init_finalizer(gm):
            def clear_static_tensor_refs():
                for name in names_to_delete:
                    gm._buffers.pop(name, None)
                    gm._parameters.pop(name, None)
                    if tc.params_flat:
                        tc.params_flat.clear()

            weakref.finalize(value, clear_static_tensor_refs)

        tx.output.add_graph_finalizer(init_finalizer)