nn_module.py 39 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976
  1. # mypy: ignore-errors
  2. import functools
  3. import inspect
  4. import itertools
  5. import types
  6. from contextlib import contextmanager, nullcontext
  7. from typing import Any, Dict, List
  8. import torch.nn
  9. from .. import trace_rules, variables
  10. from ..exc import unimplemented, UnspecializeRestartAnalysis, Unsupported
  11. from ..guards import GuardBuilder, install_guard
  12. from ..mutation_guard import GenerationTracker
  13. from ..source import (
  14. AttrSource,
  15. FSDPNNModuleSource,
  16. GetItemSource,
  17. NNModuleSource,
  18. NotNNModuleSource,
  19. )
  20. from ..utils import (
  21. get_custom_getattr,
  22. get_fake_value,
  23. is_lazy_module,
  24. is_namedtuple,
  25. is_safe_constant,
  26. istensor,
  27. istype,
  28. nnmodule_has_hooks,
  29. object_has_getattribute,
  30. proxy_args_kwargs,
  31. set_example_value,
  32. )
  33. from .base import MutableLocal, typestr, VariableTracker
  34. from .functions import invoke_and_store_as_constant
  35. from .lists import SliceVariable
  36. from .user_defined import UserDefinedObjectVariable
  37. def initialize_lazy_module(tx, mod, args, kwargs):
  38. """
  39. Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable.
  40. Used to cause lazy module to be initialized (and delete its init hook) before tracing. Especially
  41. useful now that 'allowed' modules graph-break on hooks, calling this first ensures there is no hook
  42. by the time we trace __call__ and thus no graph-break for lazy allowed modules.
  43. """
  44. if hasattr(mod, "_initialize_hook"):
  45. def convert_to_fake(x):
  46. if is_namedtuple(x):
  47. return type(x)(*(convert_to_fake(elem) for elem in x))
  48. elif isinstance(x, dict):
  49. return {k: convert_to_fake(v) for k, v in x.items()}
  50. elif isinstance(x, (list, tuple, set)):
  51. return type(x)(convert_to_fake(elem) for elem in x)
  52. elif isinstance(x, torch.fx.Proxy):
  53. return get_fake_value(x.node, tx)
  54. else:
  55. return x
  56. proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs)
  57. fake_args = [convert_to_fake(arg) for arg in proxy_args]
  58. fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()}
  59. mod._infer_parameters(mod, fake_args, fake_kwargs)
  60. def cleanup_source_for_nn_module_stack(source):
  61. # TODO(anijain2305, export-team) This is a bad hack to fix the nn module
  62. # fully_qualified_name to work with export/unflatten. It converts
  63. # mod._modules['net1'] to mod.net1.
  64. # This type of source occurs when we use UnspecializedNNModule variable
  65. # because unspecialized nn module variable inlines module __getattr__ calls.
  66. # For export, we rely heavily on NNModuleVariable and do not support
  67. # UnspecializedNNModule. But there is one case where this gets exposed -
  68. # Pippy. Pippy uses export/unflatten (an export feature) and also
  69. # monkepatches the `forward` method of a mod that forces Dynamo to use
  70. # UnspecializedNNModule. Therefore, we will need proper work to retain the
  71. # nn module stack when we let export rely on UnspecializedNNModule variable.
  72. # This does not work if we have recursively UnspecializedNNModule variables
  73. # e.g. mod._modules['net1']._modules['net2']. This is unlikely to happen in
  74. # Pippy so the hotfix is enough for Pippy.
  75. if (
  76. isinstance(source, GetItemSource)
  77. and isinstance(source.base, AttrSource)
  78. and isinstance(source.base.base, NNModuleSource)
  79. and source.base.member == "_modules"
  80. ):
  81. return AttrSource(source.base.base, source.index)
  82. return source
  83. @contextmanager
  84. def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
  85. source_for_nn_module_stack = cleanup_source_for_nn_module_stack(source)
  86. fully_qualified_name = source_for_nn_module_stack.name()
  87. try:
  88. tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__)
  89. yield
  90. finally:
  91. del tx.nn_module_stack[module_key]
  92. def guard_to_detect_forward_monkeypatching(source, mod):
  93. # Users sometimes patch the forward method of a nn module instance to
  94. # perform optimizations like quantization. Though this is not a good
  95. # software practice, but python allows this and Dynamo needs to detect
  96. # this patching.
  97. #
  98. # One way to do this is to add an ID_MATCH guard on every function
  99. # getting inlined (https://github.com/pytorch/pytorch/pull/124975). But
  100. # this increased guard overhead by around 20%.
  101. #
  102. # To keep the guard overhead down, we just guard on the `forward` being
  103. # not present in the mod __dict__. The common case of patching forward
  104. # method adds `forward` in the instance __dict__, whereas the unpatched
  105. # `forward` sits in the type(mod).__dict__
  106. if source:
  107. if "forward" in mod.__dict__ and callable(mod.__dict__["forward"]):
  108. # Monkeypatched forward method, add an ID_MATCH guard on forward function
  109. fwd = mod.__dict__["forward"]
  110. forward_source = AttrSource(source, "forward")
  111. if type(fwd) is types.MethodType:
  112. forward_source = AttrSource(forward_source, "__func__")
  113. install_guard(forward_source.make_guard(GuardBuilder.CLOSURE_MATCH))
  114. else:
  115. # Common case - check that the forward key is absent in mod __dict__
  116. install_guard(
  117. source.make_guard(
  118. functools.partial(
  119. GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr="forward"
  120. )
  121. )
  122. )
  123. class NNModuleVariable(VariableTracker):
  124. _nonvar_fields = {
  125. "module_type",
  126. "module_key",
  127. "module",
  128. *VariableTracker._nonvar_fields,
  129. }
  130. def __init__(
  131. self, module_type: type, module_key: str, module: torch.nn.Module, **kwargs
  132. ):
  133. super().__init__(**kwargs)
  134. self.module_type = module_type
  135. self.module_key = module_key
  136. self.module = module
  137. assert self.source
  138. def python_type(self):
  139. return self.module_type
  140. def _wrap_submodule(self, tx, source, submod, *key_extra, **options):
  141. return
  142. def unpack_var_sequence(self, tx):
  143. # implement list/iter/tuple/etc calls
  144. base = tx.output.get_submodule(self.module_key)
  145. if isinstance(base, torch.nn.ModuleDict):
  146. result = []
  147. for name, submod in base.items():
  148. name_var = variables.ConstantVariable.create(name)
  149. tx.output.register_attr_or_module(
  150. submod,
  151. self.module_key,
  152. name,
  153. source=NNModuleSource(GetItemSource(self.source, name)),
  154. )
  155. result.append(name_var)
  156. return result
  157. assert isinstance(
  158. base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential)
  159. ), typestr(base)
  160. assert self.source
  161. result = []
  162. for idx, submod in enumerate(base):
  163. result.append(
  164. tx.output.register_attr_or_module(
  165. submod,
  166. self.module_key,
  167. idx,
  168. source=NNModuleSource(GetItemSource(self.source, idx)),
  169. )
  170. )
  171. return result
  172. def call_hasattr(self, tx, name: str) -> "VariableTracker":
  173. mod = tx.output.get_submodule(self.module_key)
  174. result = hasattr(mod, name)
  175. install_guard(
  176. NNModuleSource(AttrSource(self.source, name)).make_guard(
  177. GuardBuilder.HASATTR
  178. )
  179. )
  180. return variables.ConstantVariable.create(result)
  181. def is_training(self, tx):
  182. mod = tx.output.get_submodule(self.module_key)
  183. return getattr(mod, "training", False)
  184. def convert_to_unspecialized(self, tx):
  185. """Restart analysis treating this module as an UnspecializedNNModuleVariable"""
  186. mod = tx.output.get_submodule(self.module_key)
  187. GenerationTracker.tag(mod)
  188. # Mark the class dynamic unless its module initialization
  189. if tx.f_code.co_name != "__init__":
  190. GenerationTracker.mark_class_dynamic(type(mod))
  191. raise UnspecializeRestartAnalysis
  192. def has_key_in_generic_dict(self, tx, key):
  193. base = tx.output.get_submodule(self.module_key)
  194. if object_has_getattribute(base):
  195. unimplemented("NNModuleVariable with custom __getattribute__")
  196. if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
  197. mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
  198. return not isinstance(mutated_attr, variables.DeletedVariable)
  199. base_dict = object.__getattribute__(base, "__dict__")
  200. return key in base_dict
  201. def _custom_getattr_fallback(self, base, tx, name, options):
  202. """Check for a __getattr__ and handle it specially if it is implemented"""
  203. if object_has_getattribute(base):
  204. unimplemented("torch.nn.Module with a custom __getattribute__ defined")
  205. getattr_fn = get_custom_getattr(base)
  206. if getattr_fn is None:
  207. return None
  208. if not isinstance(getattr_fn, types.FunctionType):
  209. unimplemented("torch.nn.Module with a non-function custom __getattr__")
  210. return variables.UserMethodVariable(getattr_fn, self, **options).call_function(
  211. tx, [variables.ConstantVariable.create(name)], {}
  212. )
  213. def var_getattr(self, tx, name):
  214. from .builder import VariableBuilder
  215. if self.source:
  216. source = AttrSource(self.source, name)
  217. else:
  218. source = None
  219. base = tx.output.get_submodule(self.module_key)
  220. base_dict = object.__getattribute__(base, "__dict__")
  221. object_member = True
  222. all_class_attribute_names = set()
  223. for x in inspect.getmro(base.__class__):
  224. all_class_attribute_names.update(x.__dict__.keys())
  225. if not self.source:
  226. unimplemented("GETATTR with no source")
  227. if name == "__dict__":
  228. return variables.GetAttrVariable(self, name, source=source)
  229. if name in base_dict:
  230. subobj = base_dict[name]
  231. elif (
  232. "_modules" in base_dict
  233. and name in base_dict["_modules"]
  234. and name not in all_class_attribute_names
  235. ):
  236. subobj = base_dict["_modules"][name]
  237. elif "_parameters" in base_dict and name in base_dict["_parameters"]:
  238. subobj = base_dict["_parameters"][name]
  239. elif "_buffers" in base_dict and name in base_dict["_buffers"]:
  240. subobj = base_dict["_buffers"][name]
  241. else:
  242. try:
  243. subobj = inspect.getattr_static(base, name)
  244. object_member = False
  245. except AttributeError:
  246. # see if we can fallback to __getattr__, which is not checked by getattr_static
  247. result = self._custom_getattr_fallback(
  248. base=base, tx=tx, name=name, options={"source": source}
  249. )
  250. if result is not None:
  251. return result
  252. # if we can't find a __getattr__, just raise the AttributeError
  253. raise
  254. if name == "forward":
  255. guard_to_detect_forward_monkeypatching(self.source, base)
  256. if name == "__class__" and not object_member:
  257. return variables.UserDefinedClassVariable(base.__class__, source=source)
  258. if object_member:
  259. return VariableBuilder(tx, NNModuleSource(source))(subobj)
  260. else:
  261. if istype(subobj, property):
  262. if self.source:
  263. # Read the class attribute to reach the property
  264. source = AttrSource(AttrSource(self.source, "__class__"), name)
  265. # Get the getter function
  266. source = AttrSource(source, "fget")
  267. return variables.UserFunctionVariable(
  268. subobj.fget,
  269. source=source,
  270. ).call_function(tx, [(self)], {})
  271. elif istype(subobj, classmethod):
  272. return variables.UserMethodVariable(
  273. subobj.__func__,
  274. variables.UserDefinedObjectVariable(type(base)),
  275. source=source,
  276. )
  277. elif istype(subobj, staticmethod):
  278. return variables.UserFunctionVariable(
  279. subobj.__get__(base), source=source
  280. )
  281. elif istype(subobj, types.FunctionType):
  282. return variables.UserMethodVariable(subobj, self, source=source)
  283. elif is_safe_constant(subobj) or istensor(subobj):
  284. # Support possibly common cases of class members
  285. return VariableBuilder(tx, NNModuleSource(source))(subobj)
  286. else:
  287. unimplemented(
  288. f"class property {name} - {typestr(base)} {typestr(subobj)}"
  289. )
  290. return variables.GetAttrVariable(self, name, source=source)
  291. def call_function(
  292. self,
  293. tx,
  294. args: "List[VariableTracker]",
  295. kwargs: "Dict[str, VariableTracker]",
  296. ) -> "VariableTracker":
  297. mod = tx.output.get_submodule(self.module_key)
  298. with record_nn_module_stack(self.module_key, self.source, tx, mod):
  299. is_lazy = is_lazy_module(mod)
  300. if (
  301. isinstance(mod, torch.nn.Sequential)
  302. and mod.__class__.forward is torch.nn.Sequential.forward
  303. ):
  304. if nnmodule_has_hooks(mod):
  305. # We do not want to unroll sequential if it has hooks, since evaporating it
  306. # will cause hooks to not fire!
  307. # This terminates and restart the tracing process
  308. self.convert_to_unspecialized(tx)
  309. # Unroll sequential
  310. assert (
  311. not is_lazy
  312. ), "Expected lazy sequential isn't a valid combination?"
  313. assert not kwargs
  314. (arg,) = args
  315. # TODO: Use named_children when it supports remove_duplicate=False.
  316. for child_name, submod in mod._modules.items():
  317. tx.call_function(
  318. tx.output.register_attr_or_module(
  319. submod,
  320. self.module_key,
  321. child_name,
  322. source=NNModuleSource(AttrSource(self.source, child_name)),
  323. ),
  324. [arg],
  325. {},
  326. )
  327. arg = tx.pop()
  328. return arg
  329. if is_lazy:
  330. # The module type will change after it is called
  331. if mod.cls_to_become is not None:
  332. self.module_type = mod.cls_to_become
  333. # The pre-hook runs to initialize the module shapes, then deletes itself. After this,
  334. # the module is more or less not lazy and can be treated as a normal module regardless of
  335. # is_allowed or other variations.
  336. initialize_lazy_module(tx, mod, args, kwargs)
  337. # If we are tracing the higher order op, we want Dynamo to step
  338. # inside the module call so that Dynamo can see the underlying
  339. # parameters and buffers and raise them as inputs to the graph.
  340. #
  341. # NB: torch.nn.utils.parametrize changes the class type of a
  342. # parametrized module such that its __module__ points to
  343. # "torch.nn.utils.parametrize".
  344. if (
  345. tx.output.is_root_tracer()
  346. and mod.__module__.startswith(("torch.nn.", "torch.ao."))
  347. and mod.__module__ != "torch.nn.utils.parametrize"
  348. ):
  349. if nnmodule_has_hooks(
  350. mod, check_forward_hooks=True, check_backward_hooks=True
  351. ):
  352. # End of fn, this bubbles up and restarts tracing.
  353. self.convert_to_unspecialized(tx)
  354. from .builder import wrap_fx_proxy
  355. return wrap_fx_proxy(
  356. tx=tx,
  357. proxy=tx.output.create_proxy(
  358. "call_module",
  359. self.module_key,
  360. *proxy_args_kwargs(args, kwargs),
  361. ),
  362. )
  363. else:
  364. assert self.source, (
  365. "Must provide a valid source in order to inline, "
  366. "since inlined function may have default args which must be guarded."
  367. )
  368. if isinstance(mod, torch.fx.GraphModule):
  369. # TODO: do we want to support __call__ for GM's?
  370. # If so at least some changes are needed, we don't allow inlining
  371. # the call_wrapped currently, and maybe other issues too
  372. fn = mod.forward
  373. fn_source = AttrSource(self.source, "forward")
  374. else:
  375. fn = mod._call_impl
  376. fn_source = AttrSource(self.source, "_call_impl")
  377. if istype(fn, types.MethodType):
  378. fn = fn.__func__
  379. fn_source = AttrSource(fn_source, "__func__")
  380. args = [self] + args
  381. else:
  382. assert istype(fn, types.FunctionType)
  383. return tx.inline_user_function_return(
  384. variables.UserFunctionVariable(fn, source=fn_source),
  385. args,
  386. kwargs,
  387. )
  388. def call_method(
  389. self,
  390. tx,
  391. name,
  392. args: "List[VariableTracker]",
  393. kwargs: "Dict[str, VariableTracker]",
  394. constant=False,
  395. ) -> "VariableTracker":
  396. from . import ConstantVariable, ListIteratorVariable, TupleVariable
  397. key = self.module_key
  398. module = tx.output.get_submodule(key)
  399. def generic_call_method_helper(name):
  400. # Helper function to put a `call_method` node in FX graph,
  401. # with nn.Module as the first arg.
  402. mod_proxy = tx.output.create_proxy(
  403. "get_attr",
  404. self.module_key,
  405. tuple(),
  406. {},
  407. )
  408. set_example_value(mod_proxy.node, module)
  409. proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs)
  410. from .builder import wrap_fx_proxy
  411. return wrap_fx_proxy(
  412. tx=tx,
  413. proxy=tx.output.create_proxy(
  414. "call_method",
  415. name,
  416. args=(mod_proxy, *proxy_args),
  417. kwargs=proxy_kwargs,
  418. ),
  419. )
  420. if name in ["_call_impl", "_wrapped_call_impl"]:
  421. # Example: `self.layer.__call__(x)`
  422. # This is used for explicit calling `__call__` in a forward function.
  423. # Dynamo inlines `__call__`, includes hooks.
  424. return self.call_function(tx, args, kwargs)
  425. elif name == "forward":
  426. # Example: `self.layer.forward(x)`
  427. # This is used for explicit calling `forward` in a forward function.
  428. # Dynamo puts `call_method` node in FX, doesn't trigger hooks.
  429. with record_nn_module_stack(self.module_key, self.source, tx, module):
  430. return generic_call_method_helper(name)
  431. if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed(
  432. inspect.getfile(module.__class__._check_input_dim)
  433. ):
  434. return ConstantVariable.create(True)
  435. if name == "_get_item_by_idx":
  436. assert args[1].is_python_constant()
  437. assert isinstance(args[0], TupleVariable)
  438. mod_var = args[0].items[args[1].value]
  439. if isinstance(mod_var, UnspecializedNNModuleVariable):
  440. return mod_var
  441. key = mod_var.module_key
  442. submod = tx.output.get_submodule(key)
  443. return tx.output.register_attr_or_module(
  444. submod,
  445. key,
  446. key,
  447. source=NNModuleSource(GetItemSource(self.source, key)),
  448. )
  449. if constant:
  450. fn = getattr(module, name)
  451. name = f"{module.__class__.__name__}_{name}_result"
  452. return invoke_and_store_as_constant(tx, fn, name, args, kwargs)
  453. def assert_all_args_kwargs_const():
  454. if not all(
  455. x.is_python_constant() for x in itertools.chain(args, kwargs.values())
  456. ):
  457. unimplemented(f"non-const NNModule method {name}")
  458. def get_kwargs(*names):
  459. assert_all_args_kwargs_const()
  460. fn = getattr(module, name)
  461. bound_args = inspect.signature(fn).bind(
  462. *([x.as_python_constant() for x in args]),
  463. **{k: v.as_python_constant() for k, v in kwargs.items()},
  464. )
  465. bound_args.apply_defaults()
  466. bound_args = bound_args.arguments
  467. return {k: bound_args[k] for k in names}
  468. def wrap_values(items):
  469. result = []
  470. for name, submod in items:
  471. result.append(
  472. tx.output.register_attr_or_module(
  473. submod,
  474. key,
  475. name,
  476. source=NNModuleSource(gen_source(self.source, name)),
  477. )
  478. )
  479. return ListIteratorVariable(result, mutable_local=MutableLocal())
  480. def named_embed(name, obj):
  481. return TupleVariable(
  482. [
  483. ConstantVariable.create(name),
  484. tx.output.register_attr_or_module(
  485. obj,
  486. key,
  487. name,
  488. source=NNModuleSource(gen_source(self.source, name)),
  489. ),
  490. ]
  491. )
  492. def gen_source(source, name):
  493. name_split = name.split(".")
  494. if name_split[0] == "":
  495. return source
  496. while len(name_split) > 0:
  497. x = name_split.pop(0)
  498. source = AttrSource(source, x)
  499. return source
  500. if name == "named_children":
  501. tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
  502. assert not (args or kwargs)
  503. result = []
  504. for name, submod in module.named_children():
  505. result.append(named_embed(name, submod))
  506. return ListIteratorVariable(result, mutable_local=MutableLocal())
  507. elif name == "named_parameters":
  508. tx.output.guard_on_key_order.add(
  509. AttrSource(self.source, "_parameters").name()
  510. )
  511. result = []
  512. for name, param in module.named_parameters(
  513. **get_kwargs("prefix", "recurse")
  514. ):
  515. result.append(named_embed(name, param))
  516. return ListIteratorVariable(result, mutable_local=MutableLocal())
  517. elif name == "named_buffers":
  518. tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers").name())
  519. result = []
  520. for name, buffer in module.named_buffers(
  521. **get_kwargs("prefix", "recurse", "remove_duplicate")
  522. ):
  523. result.append(named_embed(name, buffer))
  524. return ListIteratorVariable(result, mutable_local=MutableLocal())
  525. elif name == "named_modules":
  526. tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
  527. result = []
  528. for name, submod in module.named_modules(
  529. **get_kwargs("memo", "prefix", "remove_duplicate")
  530. ):
  531. result.append(named_embed(name, submod))
  532. return ListIteratorVariable(result, mutable_local=MutableLocal())
  533. elif name == "children":
  534. tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
  535. assert not (args or kwargs)
  536. return wrap_values(module.named_children())
  537. elif name == "modules":
  538. tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
  539. return wrap_values(module.named_modules())
  540. elif name == "parameters":
  541. tx.output.guard_on_key_order.add(
  542. AttrSource(self.source, "_parameters").name()
  543. )
  544. return wrap_values(module.named_parameters(**get_kwargs("recurse")))
  545. elif name == "buffers":
  546. tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers").name())
  547. return wrap_values(module.named_buffers(**get_kwargs("recurse")))
  548. elif name == "keys":
  549. assert not (args or kwargs)
  550. result = []
  551. for name in module.keys():
  552. result.append(ConstantVariable.create(name))
  553. return ListIteratorVariable(result, mutable_local=MutableLocal())
  554. elif name == "values":
  555. assert not (args or kwargs)
  556. return wrap_values(module.items())
  557. elif name == "items":
  558. assert not (args or kwargs)
  559. result = []
  560. for name, submod in module.items():
  561. result.append(named_embed(name, submod))
  562. return ListIteratorVariable(result, mutable_local=MutableLocal())
  563. elif name == "__len__":
  564. assert not (args or kwargs)
  565. return ConstantVariable.create(len(module))
  566. elif (
  567. name == "__contains__"
  568. and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict))
  569. and args
  570. and args[0].is_python_constant()
  571. ):
  572. return ConstantVariable.create(
  573. args[0].as_python_constant() in module._modules
  574. )
  575. elif name == "__getitem__":
  576. assert not kwargs and len(args) == 1
  577. builtin_supported = (
  578. torch.nn.ModuleDict.__getitem__,
  579. torch.nn.ModuleList.__getitem__,
  580. torch.nn.ParameterDict.__getitem__,
  581. torch.nn.ParameterList.__getitem__,
  582. torch.nn.Sequential.__getitem__,
  583. )
  584. if type(module).__getitem__ not in builtin_supported:
  585. assert isinstance(args[0], variables.ConstantVariable), typestr(args[0])
  586. key = args[0].as_python_constant()
  587. assert isinstance(key, (str, int))
  588. fn = getattr(module, name).__func__
  589. assert isinstance(fn, types.FunctionType)
  590. src = AttrSource(AttrSource(self.source, name), "__func__")
  591. return tx.inline_user_function_return(
  592. variables.UserFunctionVariable(fn, source=src),
  593. [self] + list(args),
  594. kwargs,
  595. )
  596. assert self.source
  597. if isinstance(args[0], SliceVariable):
  598. # Build a TupleVariable of NNModules
  599. result = []
  600. submods = []
  601. # Turn the slice into the list of integers
  602. keys = list(range(len(module)))[args[0].as_python_constant()]
  603. for idx, submod in enumerate(module[args[0].as_python_constant()]):
  604. key = keys[idx]
  605. src = NNModuleSource(GetItemSource(self.source, key))
  606. result.append(
  607. tx.output.register_attr_or_module(
  608. submod,
  609. key,
  610. source=src,
  611. )
  612. )
  613. submods.append(submod)
  614. new_module = torch.nn.Sequential(*submods)
  615. new_module_variable = tx.output.register_attr_or_module(
  616. new_module,
  617. f"{self}.__getitem__(slice)",
  618. source=NNModuleSource(
  619. GetItemSource(self.source, args[0].as_python_constant())
  620. ),
  621. )
  622. return new_module_variable
  623. from .tensor import SymNodeVariable
  624. if isinstance(args[0], SymNodeVariable):
  625. key = args[0].evaluate_expr(tx.output)
  626. else:
  627. key = args[0].as_python_constant()
  628. submod = module[key]
  629. return tx.output.register_attr_or_module(
  630. submod,
  631. self.module_key,
  632. key,
  633. source=NNModuleSource(GetItemSource(self.source, key)),
  634. )
  635. elif (
  636. name == "_get_abs_string_index"
  637. or (
  638. isinstance(module, torch.nn.modules.conv._ConvNd)
  639. and name == "_conv_forward"
  640. )
  641. or (
  642. isinstance(module, torch.nn.modules.conv._ConvTransposeNd)
  643. and name == "_output_padding"
  644. )
  645. ):
  646. # Inline the function
  647. fn = getattr(module, name).__func__
  648. fn_source = AttrSource(AttrSource(self.source, name), "__func__")
  649. return tx.inline_user_function_return(
  650. variables.UserFunctionVariable(fn, source=fn_source),
  651. [self] + args,
  652. kwargs,
  653. )
  654. # A loose heuristic, but seems to be generally good before we drop into the
  655. # manual handling of inputs
  656. elif (
  657. name in module.__class__.__dict__
  658. and callable(module.__class__.__dict__[name])
  659. and all(
  660. isinstance(x, variables.TensorVariable)
  661. for x in itertools.chain(args, kwargs.values())
  662. )
  663. ):
  664. return generic_call_method_helper(name)
  665. else:
  666. return super().call_method(tx, name, args, kwargs)
  667. class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
  668. _nonvar_fields = {
  669. "value_type",
  670. "is_state_mutated",
  671. *UserDefinedObjectVariable._nonvar_fields,
  672. }
  673. """
  674. The above class will specialize on the id() of a module and place
  675. parameters on the torch.fx.GraphModule. Giving one graph per
  676. module instance. This version treats nn.Modules() like other user
  677. defined objects and will pass parameters into the FX graph as inputs.
  678. Giving one graph per module class.
  679. """
  680. def __init__(self, value, **kwargs):
  681. if type(value) is torch.jit._script.RecursiveScriptModule:
  682. raise Unsupported(
  683. "ScriptModules aren't supported in UnspecializedNNModuleVariable"
  684. " becuase their .forward function isn't a static member of their type"
  685. )
  686. if "value_type" in kwargs:
  687. lazy_value_to_become = getattr(kwargs["value_type"], "cls_to_become", None)
  688. if type(value) is lazy_value_to_become:
  689. # We may have cloned a variabletracker for a LazyModule earlier (e.g. tracking side-effects)
  690. # and then later we called and mutated the LazyModule into a MaterializedModule.
  691. # We do not do the mutation upon first seeing a LazyModule since we preserve eager semantics to only
  692. # mutate upon first call, but this requires we update multiple copies of the VariableTracker post-mutation.
  693. kwargs["value_type"] = type(value)
  694. super().__init__(value=value, **kwargs)
  695. self.is_state_mutated = False
  696. @staticmethod
  697. @functools.lru_cache(None)
  698. def _nn_module_method_ids():
  699. # Allow __setattr__ to fall through to base class handler
  700. supported = {torch.nn.Module.__setattr__}
  701. return {
  702. id(x.__code__)
  703. for x in torch.nn.Module.__dict__.values()
  704. if hasattr(x, "__code__") and x not in supported
  705. }
  706. def unpack_var_sequence(self, tx):
  707. from .builder import VariableBuilder
  708. try:
  709. fn = inspect.getattr_static(self.value_type, "__iter__")
  710. except AttributeError as e:
  711. raise NotImplementedError from e
  712. if fn in (
  713. torch.nn.ModuleList.__iter__,
  714. torch.nn.ParameterList.__iter__,
  715. torch.nn.Sequential.__iter__,
  716. ):
  717. assert self.source
  718. return [
  719. VariableBuilder(tx, source=GetItemSource(self.source, idx))(item)
  720. for idx, item in enumerate(self.value)
  721. ]
  722. return super().unpack_var_sequence(tx)
  723. def call_function(
  724. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  725. ) -> "VariableTracker":
  726. mod = self.value
  727. # see comment on lazy module handling in NNModuleVariable.call_function for context
  728. if is_lazy_module(mod):
  729. if mod.cls_to_become is not None:
  730. self.value_type = mod.cls_to_become
  731. initialize_lazy_module(tx, mod, args, kwargs)
  732. name = "_call_impl"
  733. fn = getattr(self.value_type, name)
  734. if self.source:
  735. source = AttrSource(AttrSource(self.source, "__class__"), name)
  736. else:
  737. source = None
  738. guard_to_detect_forward_monkeypatching(self.source, mod)
  739. ctx = (
  740. record_nn_module_stack(str(id(mod)), self.source, tx, mod)
  741. if self.source
  742. else nullcontext()
  743. )
  744. with ctx:
  745. return variables.UserFunctionVariable(fn, source=source).call_function(
  746. tx, [self] + list(args), kwargs
  747. )
  748. def call_method(
  749. self,
  750. tx,
  751. name,
  752. args: "List[VariableTracker]",
  753. kwargs: "Dict[str, VariableTracker]",
  754. ) -> "VariableTracker":
  755. from .builder import VariableBuilder
  756. if name in ["_call_impl", "_wrapped_call_impl"]:
  757. fn = getattr(self.value_type, name)
  758. if self.source:
  759. source = AttrSource(AttrSource(self.source, "__class__"), name)
  760. else:
  761. source = None
  762. return variables.UserFunctionVariable(fn, source=source).call_function(
  763. tx, [self] + list(args), kwargs
  764. )
  765. if name not in getattr(self.value, "__dict__", {}):
  766. try:
  767. method = inspect.getattr_static(type(self.value), name)
  768. except AttributeError:
  769. method = None
  770. if method is torch.nn.Module.parameters:
  771. assert not args or kwargs
  772. if tx.output.side_effects.has_pending_mutation(self):
  773. unimplemented("Module.parameters() with pending mutation")
  774. install_guard(
  775. self.source.make_guard(GuardBuilder.NN_MODULE_PARAM_NAMES)
  776. )
  777. items = []
  778. for name, value in self.value.named_parameters():
  779. items.append(
  780. VariableBuilder(tx, AttrSource(self.source, name))(value)
  781. )
  782. return variables.ListIteratorVariable(
  783. items, mutable_local=MutableLocal()
  784. )
  785. elif isinstance(method, staticmethod):
  786. source = AttrSource(
  787. AttrSource(AttrSource(self.source, "__class__"), name), "__func__"
  788. )
  789. return tx.inline_user_function_return(
  790. variables.UserFunctionVariable(method.__func__, source=source),
  791. args,
  792. kwargs,
  793. )
  794. if (
  795. hasattr(method, "__code__")
  796. and id(method.__code__) in self._nn_module_method_ids()
  797. ):
  798. unimplemented(f"UnspecializedNNModuleVariable missing {name}")
  799. # "_parameters" in self.value.__dict__ checks that module is initialized
  800. if name == "__setattr__" and "_parameters" in self.value.__dict__:
  801. # Record if mutations happens on parameters/buffers/modules. The
  802. # mutations on these are not tracked by base class
  803. # UserDefinedObject vt. This will be used later to graph break
  804. # on seeing a paramters() and family calls.
  805. # TODO(anijain2305) - This might not be needed if we let Dynamo
  806. # inline both getattr and setattr. In that case, it should see
  807. # the lowest level dicts - _parameters and family and
  808. # automatically track mutations on those. Investigate if that
  809. # can be done.
  810. attr_name = args[0].as_python_constant()
  811. value = args[1]
  812. # This is reverse engineered by looking at nn module __setattr__
  813. # logic.
  814. if (
  815. isinstance(value, variables.TensorVariable)
  816. and value.python_type() is torch.nn.Parameter
  817. ) or attr_name in self.value.__dict__["_parameters"]:
  818. # Handle parameters
  819. self.is_state_mutated = True
  820. elif attr_name in self.value.__dict__["_buffers"]:
  821. # Handle buffers
  822. self.is_state_mutated = True
  823. elif (
  824. isinstance(
  825. value,
  826. (
  827. variables.NNModuleVariable,
  828. variables.UnspecializedNNModuleVariable,
  829. ),
  830. )
  831. or attr_name in self.value.__dict__["_modules"]
  832. ):
  833. # Handle submodules
  834. self.is_state_mutated = True
  835. return super().call_method(tx, name, args, kwargs)
  836. class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
  837. """
  838. Tracing behavior: trace into submodules and treat them as Unspecialized, do not
  839. register parameters to the top-level, treat them as function inputs.
  840. Guards behavior: if 'skip_fsdp_guards', many guards that would be installed
  841. by a vanilla UnspecializedNNModuleVariable are simply dropped, on the basis
  842. that a user wrapping their model in FSDP(model) is already opting into a
  843. requirement to not modify internal model state, which would already break FSDP without
  844. compilation.
  845. """
  846. def __init__(self, value, **kwargs):
  847. source = kwargs.get("source", None)
  848. assert (
  849. source is not None
  850. ), "FSDPManagedNNModule depends on having an accurate source to control guarding."
  851. super().__init__(value=value, **kwargs)
  852. self.source = source
  853. @staticmethod
  854. def _wrap_source(source):
  855. if not isinstance(source, (FSDPNNModuleSource, NotNNModuleSource)):
  856. if torch._dynamo.config.skip_fsdp_guards:
  857. return FSDPNNModuleSource(source)
  858. else:
  859. # this makes us behave like a usual UnspecializedNNModuleVariable for guarding purposes
  860. return NotNNModuleSource(source)
  861. else:
  862. return source
  863. def __setattr__(self, name: str, value: Any) -> None:
  864. if name == "source":
  865. value = FSDPManagedNNModuleVariable._wrap_source(value)
  866. return super().__setattr__(name, value)