side_effects.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. import warnings
  4. from typing import Any, Dict, List, Optional, Union
  5. import torch.nn
  6. from . import utils, variables
  7. from .bytecode_transformation import (
  8. create_call_function,
  9. create_call_method,
  10. create_instruction,
  11. create_load_method,
  12. )
  13. from .codegen import PyCodegen
  14. from .exc import unimplemented
  15. from .source import GlobalSource, LocalSource, Source
  16. from .utils import nn_module_new, object_new
  17. from .variables.base import (
  18. is_side_effect_safe,
  19. MutableLocalBase,
  20. MutableLocalSource,
  21. VariableTracker,
  22. )
  class MutableSideEffects(MutableLocalBase):
      """
      VariableTracker.mutable_local marker for a mutable object that already
      exists outside the traced code (for example a list passed in as an
      input).  If such an object is mutated, the mutation has to be re-applied
      after the graph runs.
      """

      def __init__(self, source: Source, is_modified: bool = False):
          """
          Args:
              source: Source describing the pre-existing object.
              is_modified: True once a mutation has been recorded for it.
          """
          # The marker is always registered as an "Existing" mutable local.
          super().__init__(MutableLocalSource.Existing)
          self.is_modified = is_modified
          self.source = source
  class AttributeMutation(MutableLocalBase):
      """
      VariableTracker.mutable_local marker to track changes to attributes.

      Base class of AttributeMutationExisting and AttributeMutationNew.
      """

      def __init__(self, typ: MutableLocalSource, source: Optional[Source]):
          """
          Args:
              typ: kind of mutable local, forwarded to MutableLocalBase.
              source: Source of the object whose attributes are tracked, or
                  None when no source is known (yet).
          """
          super().__init__(typ)
          self.source = source
  class AttributeMutationExisting(AttributeMutation):
      """
      Attribute-mutation marker for an object that has a (non-optional)
      source; it is registered with MutableLocalSource.Existing.
      """

      def __init__(self, source: Source):
          super().__init__(MutableLocalSource.Existing, source)
          # The base class already stores `source`; it is assigned again here
          # with the non-Optional annotation of this constructor.
          self.source = source
  class AttributeMutationNew(AttributeMutation):
      """
      Attribute-mutation marker registered with MutableLocalSource.Local.

      Besides the (optional) source of the object itself it remembers
      `cls_source`, the (optional) source of the object's class.
      """

      def __init__(self, source: Optional[Source], cls_source: Optional[Source]):
          super().__init__(MutableLocalSource.Local, source)
          self.cls_source = cls_source
  class SideEffects:
      """
      Track side effects (list mutation, setattr, etc) that need to be
      applied after an FX graph is run.

      State kept on every instance:

      * ``id_to_variable``: ``id(obj)`` of a tracked Python object -> the
        VariableTracker registered for it.
      * ``store_attr_mutations``: a variable's ``mutable_local`` marker -> the
        pending ``{name: value}`` writes recorded against it.
      * ``keepalive``: the tracked objects themselves (presumably so their
        ``id()`` stays valid for as long as it is used as a key).
      * ``save_for_backward``: list of ``(ctx, args)`` pairs added by
        ``track_save_for_backward``.
      * ``tensor_hooks``: integer index -> ``(tensor, hook, handle, name)``
        added by ``register_hook``.
      """

      id_to_variable: Dict[int, VariableTracker]
      store_attr_mutations: Dict[MutableLocalBase, Dict[str, VariableTracker]]
      keepalive: List[Any]

      def __init__(
          self,
          id_to_variable=None,
          store_attr_mutations=None,
          keepalive=None,
          save_for_backward=None,
          tensor_hooks=None,
      ):
          """
          Every argument is optional.  A falsy argument (None or an empty
          container) is replaced by a fresh empty container; a non-empty
          container is stored by reference, not copied.
          """
          super().__init__()
          self.id_to_variable = id_to_variable or {}
          self.store_attr_mutations = store_attr_mutations or {}
          self.keepalive = keepalive or []
          self.save_for_backward = save_for_backward or []
          self.tensor_hooks = tensor_hooks or {}

      def __eq__(self, other: object) -> bool:
          """
          Compare the tracked state of two SideEffects objects.

          `keepalive` is deliberately left out of the comparison.  `other`
          must be a SideEffects instance (enforced with an assert).
          """
          assert isinstance(other, SideEffects)
          # NB: do NOT test keepalive
          return (
              self.id_to_variable == other.id_to_variable
              and self.store_attr_mutations == other.store_attr_mutations
              and self.save_for_backward == other.save_for_backward
              and self.tensor_hooks == other.tensor_hooks
          )

      def diff(self, other: "SideEffects") -> Optional[str]:
          """
          Describe the first field in which `self` and `other` differ.

          Fields are checked in the order id_to_variable,
          store_attr_mutations, save_for_backward, tensor_hooks.

          Returns:
              A short human-readable string, or None when all four fields
              compare equal.
          """
          if self.id_to_variable != other.id_to_variable:
              sk_itv = self.id_to_variable.keys()
              ok_itv = other.id_to_variable.keys()
              if sk_itv != ok_itv:
                  return f"id_to_variable keys: {sk_itv} != {ok_itv}"
              # Feel free to augment this with more fancy diffing logic
              # if needed for debugging
              return "id_to_variable: unknown diff"

          if self.store_attr_mutations != other.store_attr_mutations:
              sk_sam = self.store_attr_mutations.keys()
              ok_sam = other.store_attr_mutations.keys()
              if sk_sam != ok_sam:
                  return f"store_attr_mutations keys: {sk_sam} != {ok_sam}"
              return "store_attr_mutations: unknown diff"

          if self.save_for_backward != other.save_for_backward:
              return "save_for_backward"

          if self.tensor_hooks != other.tensor_hooks:
              return "tensor_hooks"

          return None

      def clone(self):
          """
          Create a shallow copy.

          `id_to_variable`, `keepalive` and both levels of
          `store_attr_mutations` get new containers (the contained objects are
          shared).

          NOTE(review): `save_for_backward` and `tensor_hooks` are passed
          through as-is.  Because __init__ only replaces *falsy* arguments,
          the clone shares these two containers with `self` when they are
          non-empty and gets independent ones when they are empty.  Confirm
          with the callers of clone() whether that aliasing is intended
          before changing it.
          """
          return self.__class__(
              id_to_variable=dict(self.id_to_variable),
              store_attr_mutations={
                  k: dict(v) for k, v in self.store_attr_mutations.items()
              },
              keepalive=list(self.keepalive),
              save_for_backward=self.save_for_backward,
              tensor_hooks=self.tensor_hooks,
          )

      def __contains__(self, item):
          """True if the Python object `item` (looked up by id) is tracked."""
          return id(item) in self.id_to_variable

      def __getitem__(self, item):
          """
          Return the variable registered for the Python object `item`.

          The lookup is by `id(item)`; an untracked object raises KeyError.
          """
          return self.id_to_variable[id(item)]

      def check_allowed_side_effect(self, item):
          """
          Check that mutating `item` is permitted in the current scope.

          Returns True for AutogradFunctionContextVariable.  For anything else
          `unimplemented(...)` is called when `is_side_effect_safe` rejects the
          variable's `mutable_local`; otherwise the method returns None.
          """
          from torch._dynamo.variables.misc import AutogradFunctionContextVariable

          # People do things like self.dim = dim inside autograd.Function.
          # These are benign.
          if isinstance(item, AutogradFunctionContextVariable):
              return True
          if not is_side_effect_safe(item.mutable_local):
              unimplemented(
                  "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
              )

      def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
          """
          Record the pending write ``item.<name> = value``.

          `item` must carry an AttributeMutation marker and must pass
          check_allowed_side_effect().
          """
          assert self.is_attribute_mutation(item)
          self.check_allowed_side_effect(item)
          # Create the per-object dict on first use, then record the write.
          self.store_attr_mutations.setdefault(item.mutable_local, {})[name] = value

      def load_attr(self, item, name, deleted_ok=False):
          """
          Return the value most recently recorded for ``item.<name>``.

          Both lookups are plain dict indexing, so KeyError is raised when
          nothing has been recorded for `item` or for `name`.  When the
          recorded value is a DeletedVariable and `deleted_ok` is false,
          `unimplemented("read deleted attribute")` is called.
          """
          assert self.is_attribute_mutation(item)
          result = self.store_attr_mutations[item.mutable_local][name]
          if not deleted_ok and isinstance(result, variables.DeletedVariable):
              unimplemented("read deleted attribute")
          return result

      def store_cell(self, cellvar, value):
          """Record a write to a cell; stored under the name "cell_contents"."""
          assert isinstance(cellvar, variables.NewCellVariable)
          assert isinstance(value, variables.VariableTracker)
          self.store_attr(cellvar, "cell_contents", value)

      def load_cell(self, cellvar):
          """Return the value recorded for a cell's "cell_contents"."""
          assert isinstance(cellvar, variables.NewCellVariable)
          return self.load_attr(cellvar, "cell_contents")

      def load_global(self, gvar: VariableTracker, name: str):
          """Return the value recorded for global `name` on `gvar`."""
          assert isinstance(gvar, variables.VariableTracker)
          return self.load_attr(gvar, name)

      def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker):
          """Record a write of `value` to global `name` on `gvar`."""
          assert isinstance(gvar, variables.VariableTracker)
          assert isinstance(value, variables.VariableTracker)
          self.store_attr(gvar, name, value)

      @staticmethod
      def cls_supports_mutation_side_effects(cls):
          """
          True when a static lookup of ``__getattribute__`` on `cls` resolves
          to ``object.__getattribute__``, i.e. the class does not override it.
          """
          return (
              inspect.getattr_static(cls, "__getattribute__", None)
              is object.__getattribute__
          )

      def is_attribute_mutation(self, item):
          """True if `item.mutable_local` is an AttributeMutation marker."""
          return isinstance(item.mutable_local, AttributeMutation)

      def has_pending_mutation(self, item):
          """True if at least one attribute write is recorded for `item`."""
          return self.is_attribute_mutation(item) and bool(
              self.store_attr_mutations.get(item.mutable_local)
          )

      def has_pending_mutation_of_attr(self, item, name):
          """True if a write to attribute `name` is recorded for `item`."""
          return self.is_attribute_mutation(
              item
          ) and name in self.store_attr_mutations.get(item.mutable_local, ())

      def is_modified(self, item):
          """
          Decide whether `item` needs to be handled by the codegen_* methods.

          * AttributeMutationNew: always True.
          * any other AttributeMutation: True once an entry exists in
            `store_attr_mutations` (even an empty one).
          * otherwise: the marker's own `is_modified` flag.
          """
          if isinstance(item.mutable_local, AttributeMutationNew):
              return True
          if self.is_attribute_mutation(item):
              return item.mutable_local in self.store_attr_mutations
          return item.mutable_local.is_modified

      def _track_obj(
          self,
          item: Any,
          variable: VariableTracker,
          mutable_cls=MutableSideEffects,
      ):
          """Start tracking a new variable for mutation"""
          # The marker is built from the variable's source, so one is required.
          assert variable.source is not None
          variable.mutable_local = mutable_cls(variable.source)
          self.id_to_variable[id(item)] = variable
          self.keepalive.append(item)
          return variable

      # Public alias: tracking with the default MutableSideEffects marker.
      track_mutable = _track_obj

      def track_object_existing(
          self,
          item: Any,
          variable: VariableTracker,
      ):
          """Track `item` / `variable` with an AttributeMutationExisting marker."""
          return self._track_obj(item, variable, mutable_cls=AttributeMutationExisting)

      def track_object_new(
          self,
          cls_source: Source,
          user_cls: Any,
          variable_cls: Any,
          options,
      ):
          """
          Create a stand-in object for a new instance of `user_cls` and track it.

          The stand-in comes from `torch.autograd.Function()` for FunctionCtx,
          from `nn_module_new` for nn.Module subclasses and from `object_new`
          otherwise.  It is wrapped in `variable_cls(obj, **options)` with an
          AttributeMutationNew marker that has no source yet and remembers
          `cls_source`.

          Returns:
              The newly created variable.
          """
          if user_cls is torch.autograd.function.FunctionCtx:
              # Warnings emitted while constructing the Function are captured
              # and discarded.
              with warnings.catch_warnings(record=True):
                  obj = torch.autograd.Function()
          elif issubclass(user_cls, torch.nn.Module):
              obj = nn_module_new(user_cls)
          else:
              obj = object_new(user_cls)
          variable = variable_cls(
              obj,
              mutable_local=AttributeMutationNew(None, cls_source),
              **options,
          )
          self.id_to_variable[id(obj)] = variable
          self.keepalive.append(obj)
          return variable

      def track_cell_new(
          self,
      ):
          """
          Track a brand-new cell.

          A throw-away `object()` only provides the id used as the key in
          `id_to_variable`; the returned NewCellVariable carries an
          AttributeMutationNew marker without source or class source.
          """
          obj = object()
          variable = variables.NewCellVariable(
              mutable_local=AttributeMutationNew(None, None),
          )
          self.id_to_variable[id(obj)] = variable
          self.keepalive.append(obj)
          return variable

      def track_cell_existing(self, source: Source, item: Any):
          """Track the existing cell `item`, reachable through `source`."""
          variable = variables.NewCellVariable(
              mutable_local=AttributeMutationExisting(source),
          )
          self.id_to_variable[id(item)] = variable
          self.keepalive.append(item)
          return variable

      def track_global_existing(self, source: Source, item: Any):
          """Track the existing globals object `item`, reachable through `source`."""
          variable = variables.NewGlobalVariable(
              mutable_local=AttributeMutationExisting(source),
          )
          self.id_to_variable[id(item)] = variable
          self.keepalive.append(item)
          return variable

      def track_save_for_backward(self, ctx, args):
          """
          Remember a `ctx.save_for_backward(*args)` call so that
          codegen_save_tempvars() can emit it later.
          """
          assert isinstance(ctx, variables.AutogradFunctionContextVariable)
          self.save_for_backward.append((ctx, args))

      def track_tensor_variables_from_runahead_side_effects(self, other):
          """
          Adopt the TensorVariables that `other` tracks and `self` does not.

          Each adopted variable is re-registered through
          track_object_existing(), which replaces its `mutable_local`.
          """
          # In higher order ops we want to keep track of tensors seen in the
          # speculate_subgraph so that we don't lift them again as a new input in
          # other speculate_subgraph or in the root tracer.
          for other_item in other.keepalive:
              other_id = id(other_item)
              other_variable = other.id_to_variable[other_id]
              if other_id not in self.id_to_variable and isinstance(
                  other_variable, variables.TensorVariable
              ):
                  self.track_object_existing(other_item, other_variable)

      def prune_dead_object_new(self, tx):
          """
          Drop AttributeMutationNew objects that nothing live refers to.

          Liveness roots are `tx.stack`, `tx.symbolic_locals`, every tracked
          variable that is not itself AttributeMutationNew, and the values of
          all recorded attribute writes.  Afterwards `id_to_variable` and
          `store_attr_mutations` are rebuilt without the dead entries.
          """
          live_new_objects = set()
          # Read by visit() through the closure; rebound by the loop over
          # store_attr_mutations below.
          skip_obj = None

          def visit(var: VariableTracker):
              # Collect every AttributeMutationNew marker except skip_obj.
              if (
                  isinstance(var.mutable_local, AttributeMutationNew)
                  and var.mutable_local is not skip_obj
              ):
                  live_new_objects.add(var.mutable_local)

          def is_live(var: Union[MutableLocalBase, VariableTracker]):
              # Only AttributeMutationNew markers can be dead; everything else
              # is kept unconditionally.
              if isinstance(var, AttributeMutationNew):
                  return var in live_new_objects
              if isinstance(var, VariableTracker):
                  return is_live(var.mutable_local)
              return True

          VariableTracker.visit(visit, (tx.stack, tx.symbolic_locals))
          for var in self.id_to_variable.values():
              if not isinstance(var.mutable_local, AttributeMutationNew):
                  VariableTracker.visit(visit, var)

          # While walking the writes recorded on one object, that object is
          # skip_obj, so a reference to itself from its own attributes does
          # not mark it live.
          for skip_obj, setattrs in self.store_attr_mutations.items():
              VariableTracker.visit(visit, setattrs)

          self.id_to_variable = {
              k: v for k, v in self.id_to_variable.items() if is_live(v)
          }
          self.store_attr_mutations = {
              k: v for k, v in self.store_attr_mutations.items() if is_live(k)
          }

      def mutation(self, var):
          """
          Note that `var` has been mutated.

          For a MutableSideEffects marker, the marker is replaced by a new one
          with the same source and `is_modified=True`.  Other marker types are
          left unchanged after the scope check.
          """
          self.check_allowed_side_effect(var)
          if isinstance(var.mutable_local, MutableSideEffects):
              var.mutable_local = MutableSideEffects(var.mutable_local.source, True)

      def _get_modified_vars(self):
          """Tracked variables for which is_modified() is true, in dict order."""
          return [var for var in self.id_to_variable.values() if self.is_modified(var)]

      def codegen_save_tempvars(self, cg: PyCodegen):
          """
          Emit, through `cg`, the code that materializes modified objects as
          cached values, followed by the recorded save_for_backward calls.

          The statement order inside each branch determines the emitted
          instruction order and must not be changed.
          """
          for var in self._get_modified_vars():
              if isinstance(
                  var.mutable_local, (AttributeMutationExisting, AttributeMutationNew)
              ) and isinstance(var, variables.NewCellVariable):
                  # Cells: call utils.make_cell() with no arguments and cache
                  # the result for `var`.
                  cg.load_import_from(utils.__name__, "make_cell")
                  cg.extend_output(create_call_function(0, True))
                  cg.add_cache(var)
                  if isinstance(var.mutable_local, AttributeMutationNew):
                      # A new cell has no source yet; point it at the temp.
                      var.mutable_local.source = LocalSource(cg.tempvars[var])  # type: ignore[attr-defined]
              elif isinstance(var.mutable_local, AttributeMutationNew):
                  if isinstance(var, variables.AutogradFunctionContextVariable):
                      unimplemented("AutogradFunctionContextVariable escaped")
                  # Pick the allocator: nn_module_new when an
                  # "__call_nn_module_init" entry was recorded for the object,
                  # object_new otherwise.
                  if "__call_nn_module_init" in self.store_attr_mutations.get(
                      var.mutable_local, {}
                  ):
                      assert isinstance(var, variables.UnspecializedNNModuleVariable)
                      cg.load_import_from(utils.__name__, "nn_module_new")
                  else:
                      cg.load_import_from(utils.__name__, "object_new")
                  # Call the allocator with the class as its single argument.
                  cg(var.mutable_local.cls_source)
                  cg.extend_output(create_call_function(1, True))
                  cg.add_cache(var)
                  var.mutable_local.source = LocalSource(cg.tempvars[var])
              elif var in cg.tempvars:
                  assert cg.tempvars.get(var) is None
                  # subsequent usage should point to the original variable
                  cg(var.mutable_local.source)
                  cg.add_cache(var)

          # Replay every recorded ctx.save_for_backward(*args) call and drop
          # its return value.
          for ctx, args in self.save_for_backward:
              cg(ctx.source)
              cg.load_method("save_for_backward")
              for arg in args:
                  cg(arg)
              cg.extend_output(
                  [
                      *create_call_method(len(args)),
                      create_instruction("POP_TOP"),
                  ]
              )

      def register_hook(self, tensor, hook, handle, name):
          """
          Record a tensor hook registration.

          Args:
              tensor: TensorVariable the hook is registered on.
              hook: VariableTracker for the hook callable.
              handle: RemovableHandleVariable (with a truthy `mutable_local`)
                  that receives the index of the new entry in `handle.idx`.
              name: name of the torch.Tensor attribute used to register the
                  hook; must exist on torch.Tensor.
          """
          assert isinstance(tensor, variables.TensorVariable)
          assert isinstance(hook, variables.VariableTracker)
          assert (
              isinstance(handle, variables.RemovableHandleVariable)
              and handle.mutable_local
          )
          assert hasattr(torch.Tensor, name)
          idx = len(self.tensor_hooks)
          # duplicate index possible because of self.remove_hook()
          while idx in self.tensor_hooks:
              idx += 1
          self.tensor_hooks[idx] = (tensor, hook, handle, name)
          # NOTE(review): this is a truthiness test, so a handle that already
          # holds index 0 is not detected.  An `is None` check would be
          # stricter, but depends on the default of RemovableHandleVariable.idx,
          # which is defined elsewhere - confirm before changing.
          assert not handle.idx
          handle.idx = idx

      def remove_hook(self, idx):
          """Forget the hook stored under `idx` (KeyError if there is none)."""
          del self.tensor_hooks[idx]

      def codegen_hooks(self, cg):
          """
          Emit a ``tensor.<name>(hook)`` call for every recorded hook and
          associate the call's result with the hook's handle.
          """
          for (
              tensor,
              hook,
              handle,
              name,
          ) in self.tensor_hooks.values():
              # Note: [On tensor.register_hook]
              #
              # register_hook on a tensor, AKA backward hooks, have slightly nuanced differences in how they are implemented
              # when it comes to hooks on objects with sources (inputs, params) vs objects without sources (intermediaries).
              #
              # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph.
              # Instead, these are tracked and stashed as a global variable, enabling their association with tensors in
              # the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able
              # tensors. Because a source indicates knowledge of this object outside the torch compile region, and
              # because we are running residuals firmly before .backward() can be run, it is sound to invoke
              # `register_hook` on a known tensor.
              #
              # For tensors without a source, we support a limited subset of hooks. Global functions only, and
              # compiled_autograd must be enabled or we will graph break.
              #
              # Handling the Handle: When a user retains the register_hook result in a handle, we intercept the
              # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed
              # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the
              # stack intact.
              #
              # Dynamo Tensor Hooks Workflow:
              # - Functions passed to register_hook are lifted globally.
              # - For tensors with sources:
              #   - In the "side_effects" phase of codegen, we iterate over tensors with hooks to:
              #     - Generate the tensor.
              #     - Issue a register_hook call on the tensor, linking to the globally stored function.
              #     - Incorporate a handle if one was established in the eager phase.
              # - For tensors without sources:
              #   - We don't generate any instructions for registering a hook.
              #   - Handles from intermediary hooks are NYI.
              #   - We produce a call function that utilizes the trace_wrapped higher order op, closing over it.
              #   - We then manually insert the call function above into the graph.
              # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST.
              assert tensor.source, "Hooks on non input tensors NYI - should not get here"
              cg(tensor)
              cg.extend_output([cg.create_load_attr(name)])
              cg(hook)
              cg.extend_output(create_call_function(1, True))

              # Adding the handle to the cache means RemovableHandleVariable().reconstruct() will
              # be associated with the return value of register_hook().  This consumes the top of stack.
              cg.add_cache(handle)

      def codegen_update_mutated(self, cg: PyCodegen):
          """
          Emit the code that writes recorded mutations back to their objects.

          For every modified variable the operands are emitted right away,
          while the instruction(s) that actually perform the mutation are
          collected in `suffixes` and emitted at the very end in reverse order
          of collection.  TupleIteratorVariable is the exception: its
          `iter_next` calls are emitted immediately.

          Raises:
              AssertionError: for a modified variable of an unhandled type.
          """
          suffixes = []
          for var in self._get_modified_vars():
              if isinstance(var, variables.ListVariable):
                  # old[:] = new
                  cg(var, allow_cache=False)
                  cg(var.mutable_local.source)  # type: ignore[attr-defined]
                  cg.extend_output(
                      [
                          cg.create_load_const(None),
                          cg.create_load_const(None),
                          create_instruction("BUILD_SLICE", arg=2),
                      ]
                  )
                  suffixes.append([create_instruction("STORE_SUBSCR")])
              elif isinstance(var, variables.ConstDictVariable):
                  # old.clear(); old.update(new)
                  cg.tx.output.update_co_names("clear")
                  cg.tx.output.update_co_names("update")

                  # `update` and its argument are emitted first and `clear`
                  # last, so that the suffix below calls clear() before
                  # update().
                  cg(var.mutable_local.source)  # type: ignore[attr-defined]
                  cg.extend_output([create_load_method("update")])
                  cg(var, allow_cache=False)

                  cg(var.mutable_local.source)  # type: ignore[attr-defined]
                  cg.extend_output([create_load_method("clear")])

                  suffixes.append(
                      [
                          *create_call_method(0),  # clear
                          create_instruction("POP_TOP"),
                          *create_call_method(1),  # update
                          create_instruction("POP_TOP"),
                      ]
                  )
              elif self.is_attribute_mutation(var):
                  for name, value in self.store_attr_mutations.get(
                      var.mutable_local, {}
                  ).items():
                      if isinstance(var, variables.NewGlobalVariable):
                          # Global write: value now, STORE_GLOBAL later.
                          cg.tx.output.update_co_names(name)
                          cg(value)
                          assert isinstance(var.mutable_local.source, GlobalSource)  # type: ignore[attr-defined]
                          suffixes.append(
                              [create_instruction("STORE_GLOBAL", argval=name)]
                          )
                      elif name == "__call_nn_module_init":
                          pass  # handled in codegen_save_tempvars
                      elif isinstance(value, variables.DeletedVariable):
                          # A deletion is only emitted for an existing object
                          # whose wrapped `value` actually has the attribute.
                          if isinstance(
                              var.mutable_local, AttributeMutationExisting
                          ) and hasattr(getattr(var, "value", None), name):
                              cg.tx.output.update_co_names(name)
                              cg(var.mutable_local.source)
                              suffixes.append(
                                  [create_instruction("DELETE_ATTR", argval=name)]
                              )
                      elif (
                          isinstance(var, variables.UserDefinedObjectVariable)
                          and var.needs_slow_setattr()
                      ):
                          # __setattr__ is defined on this object, so call object.__setattr__ directly
                          cg.load_import_from("builtins", "object")
                          cg.load_method("__setattr__")
                          cg(var.mutable_local.source)  # type: ignore[attr-defined]
                          cg(variables.ConstantVariable(name))
                          cg(value)
                          suffixes.append(
                              [*create_call_method(3), create_instruction("POP_TOP")]
                          )
                      else:
                          # Plain attribute write: value and object now,
                          # STORE_ATTR later.
                          cg.tx.output.update_co_names(name)
                          cg(value)
                          cg(var.mutable_local.source)
                          suffixes.append([create_instruction("STORE_ATTR", argval=name)])
              elif isinstance(var, variables.TupleIteratorVariable):
                  # Call utils.iter_next(source) once per `var.index` and
                  # discard each result.
                  for _ in range(var.index):
                      cg.load_import_from(utils.__name__, "iter_next")
                      cg(var.mutable_local.source)  # type: ignore[attr-defined]
                      cg.call_function(1, True)
                      cg.pop_top()
              else:
                  raise AssertionError(type(var))

          # do all the actual mutations at the very end to handle dependencies
          for suffix in reversed(suffixes):
              cg.extend_output(suffix)

      def is_empty(self):
          """
          True when there is nothing to replay: no tracked variable is
          modified and no tensor hook or save_for_backward call is recorded.
          """
          return not (
              any(map(self.is_modified, self.id_to_variable.values()))
              or self.tensor_hooks
              or self.save_for_backward
          )

      def clear(self):
          """
          Drop the tracked objects and their variables.

          Only `keepalive` and `id_to_variable` are emptied (in place);
          `store_attr_mutations`, `save_for_backward` and `tensor_hooks` are
          left untouched.
          """
          self.keepalive.clear()
          self.id_to_variable.clear()