# mypy: allow-untyped-defs
import contextlib
import copy
import itertools
import linecache
import os
import sys
import traceback
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

import torch
import torch.nn as nn
import torch.overrides
from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer

from ._compatibility import compatibility
from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode

__all__ = [
    "reduce_graph_module",
    "reduce_package_graph_module",
    "reduce_deploy_graph_module",
    "GraphModule",
]

_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"


# Normal exec loses the source code; however, we can work with
# the linecache module to recover it.
# Using _exec_with_source will add it to our local cache
# and then tools like TorchScript will be able to get source info.
class _EvalCacheLoader:
    def __init__(self):
        self.eval_cache = {}
        self.next_id = 0

    def cache(self, src: str, globals: Dict[str, Any], co_fields=None):
        """Store the source in a private cache, and add a lazy entry in linecache
        that allows the source to be retrieved by 'filename'.

        Args:
            src (str): The module source to cache
            globals (dict): The module globals

        Returns:
            str: The cache key (and dummy filename) generated for src.
        """
        key = self._get_key()
        if co_fields:
            key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
        self.eval_cache[key] = src

        # Don't mutate globals so that this loader is only used
        # to populate linecache, and doesn't interact with other modules
        # that might check `__loader__`
        globals_copy = globals.copy()
        globals_copy["__file__"] = key
        globals_copy["__name__"] = key
        globals_copy["__loader__"] = self
        linecache.lazycache(key, globals_copy)

        return key

    # Part of the loader protocol (PEP 302)
    # linecache will use this method when trying to find source code
    def get_source(self, module_name) -> Optional[str]:
        if module_name in self.eval_cache:
            return self.eval_cache[module_name]
        return None

    def _get_key(self):
        key = f"<eval_with_key>.{self.next_id}"
        self.next_id += 1
        return key


_loader = _EvalCacheLoader()


def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
    key = _loader.cache(src, globals, co_fields)
    exec(compile(src, key, "exec"), globals)
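
# A minimal sketch (names below are illustrative, not part of this module's
# API) of why the loader above matters: once source is cached under its
# generated key, linecache can recover it, so tracebacks and inspect-like
# tooling can show the generated lines.
#
#     example_globals = {}
#     _exec_with_source("def f(x):\n    return x + 1\n", example_globals)
#     key = example_globals["f"].__code__.co_filename  # e.g. "<eval_with_key>.0"
#     print("".join(linecache.getlines(key)))          # prints the source of f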
def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None):
    return _method_from_src(
        method_name="forward", src=src, globals=globals, co_fields=co_fields
    )


def _method_from_src(
    method_name: str, src: str, globals: Dict[str, Any], co_fields=None
) -> Callable:
    # avoid mutating the passed-in dict
    globals_copy = globals.copy()
    _exec_with_source(src, globals_copy, co_fields)
    fn = globals_copy[method_name]
    del globals_copy[method_name]
    return fn


def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
    if name in _custom_builtins:
        return _custom_builtins[name].import_str
    if _is_from_torch(name):
        return "import torch"
    module_name, attr_name = importer.get_name(obj)
    return f"from {module_name} import {attr_name} as {name}"


def _format_import_block(globals: Dict[str, Any], importer: Importer):
    import_strs: Set[str] = {
        _format_import_statement(name, obj, importer) for name, obj in globals.items()
    }
    # Sort the imports so we have a stable import block that allows us to
    # hash the graph module and get a consistent key for use in a cache.
    return "\n".join(sorted(import_strs))
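
# A rough illustration (hypothetical module and attribute names, not exercised
# anywhere in this file): for globals like {"torch": torch, "helper": my_pkg.helper},
# the block produced with sys_importer would be the sorted, deduplicated lines
#
#     from my_pkg import helper as helper
#     import torch
#
# which is what makes serialized forward() source self-contained.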
@compatibility(is_backward_compatible=True)
def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
    # BC: attribute name was changed from `code` to `_code` to facilitate
    # making `code` into a property and adding a docstring to it
    fn_src = body.get("_code") or body["code"]
    forward = _forward_from_src(import_block + fn_src, {})
    return _deserialize_graph_module(forward, body)


@compatibility(is_backward_compatible=True)
def reduce_package_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
) -> torch.nn.Module:
    forward = importer.import_module(generated_module_name).forward
    return _deserialize_graph_module(forward, body)


@compatibility(is_backward_compatible=True)
def reduce_deploy_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], import_block: str
) -> torch.nn.Module:
    ns = {}
    ns["__builtins__"] = importer.patched_builtins
    fn_src = body.get("_code")
    assert fn_src is not None
    forward = _forward_from_src(import_block + fn_src, ns)
    return _deserialize_graph_module(forward, body)


# We create a dummy class here because symbolic_trace pulls the forward()
# function off of the class, rather than the instance. This class is used
# in _deserialize_graph_module() below.
class _CodeOnlyModule(torch.nn.Module):
    def __init__(self, body):
        super().__init__()
        self.__dict__ = body
def _deserialize_graph_module(forward, body: Dict[Any, Any], graph_module_cls=None) -> torch.nn.Module:
    """
    Deserialize a GraphModule given the dictionary of the original module,
    using the code to reconstruct the graph. We delete the actual graph before
    saving the dictionary so that changes to the in-memory graph format do not
    get serialized.
    """

    # Try to retrieve the forward source in a backward-compatible way
    _CodeOnlyModule.forward = forward

    tracer_cls = body.get("_tracer_cls")
    if tracer_cls is None:
        from ._symbolic_trace import Tracer

        tracer_cls = Tracer

    graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule")

    # This is a workaround for a mypy linter issue related to
    # passing base class as an argument - https://github.com/python/mypy/issues/5865.
    cls_tracer: Any = tracer_cls

    class KeepModules(cls_tracer):
        # we shouldn't trace into any of the submodules,
        # because they were not traced in the original GraphModule
        def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
            return True

    com = _CodeOnlyModule(body)

    tracer_extras = body.get("_tracer_extras", {})
    graph = KeepModules().trace(com, **tracer_extras)

    # Manually set Tracer class on the reconstructed Graph, to avoid
    # referencing the private local subclass KeepModules.
    graph._tracer_cls = tracer_cls
    from ._lazy_graph_module import _make_graph_module

    gm = _make_graph_module(
        com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls
    )
    # The GraphModule constructor only retains attributes referenced by the graph.
    # Our goal here is to return a GraphModule that is as close as possible to
    # the one that was put into the package, so if any additional attributes
    # were present in body, we should keep them.
    for k, v in body.items():
        if not hasattr(gm, k):
            setattr(gm, k, v)
    return gm
# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
# This installs empty Modules where none exist yet if they are subpaths of target
def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
    *prefix, field = target.split(".")
    for item in prefix:
        f = getattr(from_module, item)
        t = getattr(to_module, item, None)
        if f is t:
            # we have already installed one of its parents
            # (e.g. target = root.linear.weight, but we have already installed root.linear)
            # once we install a parent, we no longer need to copy the children
            # since all the needed properties will already be present
            return

        if t is None:
            t = torch.nn.Module()
            setattr(to_module, item, t)
        from_module, to_module = f, t

    orig = getattr(from_module, field)
    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
    # So, we register it as a named buffer in the target module.
    if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
        to_module.register_buffer(field, orig)
    else:
        setattr(to_module, field, orig)
# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'
# This installs empty Modules where none exist yet if they are subpaths of target
def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
    *prefix, field = target.split(".")
    for item in prefix:
        t = getattr(to_module, item, None)
        if t is None:
            t = torch.nn.Module()
            setattr(to_module, item, t)
        to_module = t

    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
    # So, we register it as a named buffer in the target module.
    if isinstance(from_obj, torch.Tensor) and not isinstance(
        from_obj, torch.nn.Parameter
    ):
        to_module.register_buffer(field, from_obj)
    else:
        setattr(to_module, field, from_obj)
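
# A small sketch of the helper above (the module and attribute names are made
# up for illustration): assigning to a dotted path creates the intermediate
# empty Modules, and plain tensors land as named buffers.
#
#     root = torch.nn.Module()
#     _assign_attr(torch.zeros(2), root, "sub.child.stats")
#     assert isinstance(root.sub, torch.nn.Module)
#     assert "stats" in dict(root.sub.child.named_buffers())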
class _WrappedCall:
    def __init__(self, cls, cls_call):
        self.cls = cls
        self.cls_call = cls_call

    # Previously, if an error occurred when valid
    # symbolically-traced code was run with an invalid input, the
    # user would see the source of the error as coming from
    # `File "<eval_with_key>.N"`, where N is some number. We use
    # this function to generate a more informative error message. We
    # return the traceback itself, a message explaining that the
    # error occurred in a traced Module's generated forward
    # function, and five lines of context surrounding the faulty
    # line
    @staticmethod
    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
        # auxiliary variables (for readability)
        err_lineno = frame_summary.lineno
        assert err_lineno is not None
        line = frame_summary.line
        assert line is not None
        err_line_len = len(line)
        all_src_lines = linecache.getlines(frame_summary.filename)

        # constituent substrings of the error message
        tb_repr = torch._dynamo.disable(traceback.format_exc)()
        custom_msg = (
            "Call using an FX-traced Module, "
            f"line {err_lineno} of the traced Module's "
            "generated forward function:"
        )
        before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
        marker = "~" * err_line_len + "~~~ <--- HERE"
        err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])

        # joined message
        return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])

    def __call__(self, obj, *args, **kwargs):
        try:
            if self.cls_call is not None:
                return self.cls_call(obj, *args, **kwargs)
            else:
                return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
        except Exception as e:
            assert e.__traceback__
            topmost_framesummary: traceback.FrameSummary = (
                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]
            )  # type: ignore[arg-type]
            if "eval_with_key" in topmost_framesummary.filename:
                print(
                    _WrappedCall._generate_error_message(topmost_framesummary),
                    file=sys.stderr,
                )
                raise e.with_traceback(None)  # noqa: B904
            else:
                raise e
@compatibility(is_backward_compatible=True)
class GraphModule(torch.nn.Module):
    """
    GraphModule is an nn.Module generated from an fx.Graph. GraphModule has a
    ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
    from that ``graph``.

    .. warning::

        When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
        regenerated. However, if you edit the contents of the ``graph`` without reassigning
        the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
        code.
    """
    def __new__(cls: "Type[GraphModule]", *args, **kwargs):
        # each instance of a graph module needs its own forward method
        # so create a new singleton class for each instance.
        # it is a subclass of the user-defined class, the only difference
        # is an extra layer to install the forward method

        # address issue described at https://github.com/pytorch/pytorch/issues/63883
        # in other words, traverse class hierarchy to fix the redundant class definition problem
        for t in cls.__mro__:
            c = t.__qualname__.split(".")[-1]
            if c != "GraphModuleImpl":
                cls = t
                break

        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
            pass

        return super().__new__(GraphModuleImpl)
    @compatibility(is_backward_compatible=True)
    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        class_name: str = "GraphModule",
    ):
        """
        Construct a GraphModule.

        Args:

            root (Union[torch.nn.Module, Dict[str, Any]]):
                ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
                In the case that ``root`` is a Module, any references to Module-based objects (via qualified
                name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
                within ``root``'s Module hierarchy into the GraphModule's module hierarchy.

                In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
                looked up directly in the dict's keys. The object mapped to by the Dict will be copied
                over into the appropriate place within the GraphModule's module hierarchy.

            graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation

            class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
                error messages will report as originating from ``GraphModule``. It may be helpful to set this
                to ``root``'s original name or a name that makes sense within the context of your transform.
        """
        super().__init__()
        self.__class__.__name__ = class_name
        if isinstance(root, torch.nn.Module):
            if hasattr(root, "training"):
                self.training = root.training

            # When we pickle/unpickle graph module, we don't want to drop any module or attributes.
            if isinstance(root, _CodeOnlyModule):
                for k, _ in root.named_children():
                    _copy_attr(root, self, k)

                for k, _ in root.named_buffers():
                    _copy_attr(root, self, k)

                for k, _ in root.named_parameters():
                    _copy_attr(root, self, k)

            for node in graph.nodes:
                if node.op in ["get_attr", "call_module"]:
                    assert isinstance(node.target, str)
                    _copy_attr(root, self, node.target)
        elif isinstance(root, dict):
            targets_to_copy = []
            for node in graph.nodes:
                if node.op in ["get_attr", "call_module"]:
                    assert isinstance(node.target, str)
                    if node.target not in root:
                        raise RuntimeError(
                            "Node "
                            + str(node)
                            + " referenced target "
                            + node.target
                            + " but that target was not provided in ``root``!"
                        )
                    targets_to_copy.append(node.target)

            # Sort targets in ascending order of the # of atoms.
            # This will ensure that less deeply nested attributes are assigned
            # before more deeply nested attributes. For example, foo.bar
            # will be assigned before foo.bar.baz. Otherwise, we might assign
            # the user-provided ``foo.bar`` and wipe out the previously-assigned
            # ``foo.bar.baz``
            targets_to_copy.sort(key=lambda t: t.count("."))
            for target_to_copy in targets_to_copy:
                _assign_attr(root[target_to_copy], self, target_to_copy)
        else:
            raise RuntimeError("Unsupported type " + str(root) + " passed for root!")

        self.graph = graph

        # Store the Tracer class responsible for creating a Graph separately as part of the
        # GraphModule state, except when the Tracer is defined in a local namespace.
        # Locally defined Tracers are not pickleable. This is needed because torch.package will
        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
        # to re-create the Graph during deserialization.
        self._tracer_cls = None
        if (
            self.graph._tracer_cls
            and "<locals>" not in self.graph._tracer_cls.__qualname__
        ):
            self._tracer_cls = self.graph._tracer_cls

        self._tracer_extras = {}
        if self.graph._tracer_extras:
            self._tracer_extras = self.graph._tracer_extras

        # Dictionary to store metadata
        self.meta: Dict[str, Any] = {}
        self._replace_hook = None
        self._create_node_hooks: List[Callable] = []
        self._erase_node_hooks: List[Callable] = []

    # TorchScript breaks trying to compile the graph setter because of the
    # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
    #
    # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
    __jit_unused_properties__ = ["graph"]

    @property
    def graph(self) -> Graph:
        """
        Return the ``Graph`` underlying this ``GraphModule``
        """
        return self._graph

    @graph.setter
    def graph(self, g: Graph) -> None:
        """
        Set the underlying ``Graph`` for this ``GraphModule``. This will internally
        recompile the ``GraphModule`` so that the generated ``forward()`` function
        corresponds to ``g``
        """
        assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}"
        self._graph = g
        g.owning_module = self
        self.recompile()
    @compatibility(is_backward_compatible=False)
    def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
        """Dumps out module to ``folder`` with ``module_name`` so that it can be
        imported with ``from <folder> import <module_name>``

        Args:

            folder (Union[str, os.PathLike]): The folder to write the code out to

            module_name (str): Top-level name to use for the ``Module`` while
                writing out the code
        """
        folder = Path(folder)
        Path(folder).mkdir(exist_ok=True)
        torch.save(self.state_dict(), folder / "state_dict.pt")
        tab = " " * 4
        custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()])
        model_str = f"""
import torch
{custom_builtins}

from torch.nn import *
class {module_name}(torch.nn.Module):
    def __init__(self):
        super().__init__()
"""

        def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
            safe_reprs = [
                nn.Linear,
                nn.Conv1d,
                nn.Conv2d,
                nn.Conv3d,
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
            ]
            if type(module) in safe_reprs:
                return f"{module.__repr__()}"
            else:
                return None

        blobified_modules = []
        for module_name, module in self.named_children():
            module_str = _gen_model_repr(module_name, module)

            # If the module doesn't have a safe repr, save it as a pickled blob
            if module_str is None:
                module_file = folder / f"{module_name}.pt"
                torch.save(module, module_file)
                blobified_modules.append(module_name)
                module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
                module_str = f"torch.load(r'{module_file}') # {module_repr}"

            model_str += f"{tab*2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in self._buffers.items():
            if buffer is None:
                continue
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"

        for param_name, param in self._parameters.items():
            if param is None:
                continue
            model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"

        model_str += (
            f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
        )
        model_str += f"{_addindent(self.code, 4)}\n"

        module_file = folder / "module.py"
        module_file.write_text(model_str)

        init_file = folder / "__init__.py"
        init_file.write_text("from .module import *")

        if len(blobified_modules) > 0:
            warnings.warn(
                "Was not able to save the following children modules as reprs - "
                f"saved as pickled files instead: {blobified_modules}"
            )
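
    # A hedged usage sketch (the folder and class names are illustrative):
    #
    #     gm = torch.fx.symbolic_trace(MyModel())         # MyModel: your own nn.Module
    #     gm.to_folder("exported", module_name="MyModel")
    #     # afterwards, in a fresh interpreter run from the parent directory:
    #     #     from exported import MyModel
    #     #     model = MyModel()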
    @compatibility(is_backward_compatible=True)
    def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
        """
        Adds the given submodule to ``self``.

        This installs empty Modules where none exist yet if they are
        subpaths of ``target``.

        Args:
            target: The fully-qualified string name of the new submodule
                (See example in ``nn.Module.get_submodule`` for how to
                specify a fully-qualified string.)
            m: The submodule itself; the actual object we want to
                install in the current Module

        Returns:
            bool: Whether or not the submodule could be inserted. For
                this method to return True, each object in the chain
                denoted by ``target`` must either a) not exist yet,
                or b) reference an ``nn.Module`` (not a parameter or
                other attribute)
        """
        *prefix, field = target.split(".")
        mod: torch.nn.Module = self

        for item in prefix:
            submod = getattr(mod, item, None)

            if submod is None:
                submod = torch.nn.Module()
                setattr(mod, item, submod)

            if not isinstance(submod, torch.nn.Module):
                return False

            mod = submod

        mod.add_module(field, m)
        return True
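
    # A small sketch (the target path is made up): intermediate empty Modules
    # "backbone" and "head" are created on demand.
    #
    #     ok = gm.add_submodule("backbone.head.linear", torch.nn.Linear(4, 2))
    #     assert ok and isinstance(gm.get_submodule("backbone.head.linear"), torch.nn.Linear)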
    @compatibility(is_backward_compatible=True)
    def delete_submodule(self, target: str) -> bool:
        """
        Deletes the given submodule from ``self``.

        The module will not be deleted if ``target`` is not a valid
        target.

        Args:
            target: The fully-qualified string name of the submodule
                to delete. (See example in ``nn.Module.get_submodule``
                for how to specify a fully-qualified string.)

        Returns:
            bool: Whether or not the target string referenced a
                submodule we want to delete. A return value of ``False``
                means that the ``target`` was not a valid reference to
                a submodule.
        """
        atoms = target.split(".")
        path, target_submod = atoms[:-1], atoms[-1]
        mod: torch.nn.Module = self

        # Get the parent module
        for item in path:
            if not hasattr(mod, item):
                return False

            mod = getattr(mod, item)

            if not isinstance(mod, torch.nn.Module):
                return False

        if not hasattr(mod, target_submod):
            return False

        if not isinstance(getattr(mod, target_submod), torch.nn.Module):
            return False

        delattr(mod, target_submod)
        return True
    @compatibility(is_backward_compatible=True)
    def delete_all_unused_submodules(self) -> None:
        """
        Deletes all unused submodules from ``self``.

        A Module is considered "used" if any one of the following is
        true:
        1. It has children that are used
        2. Its forward is called directly via a ``call_module`` node
        3. It has a non-Module attribute that is used from a
        ``get_attr`` node

        This method can be called to clean up an ``nn.Module`` without
        manually calling ``delete_submodule`` on each unused submodule.
        """
        used: List[str] = []

        for node in self.graph.nodes:
            if node.op == "call_module" or node.op == "get_attr":
                # A list of strings representing the different parts
                # of the path. For example, `foo.bar.baz` gives us
                # ["foo", "bar", "baz"]
                fullpath = node.target.split(".")

                # If we're looking at multiple parts of a path, join
                # them with a dot. Otherwise, return that single
                # element without doing anything to it.
                def join_fn(x: str, y: str) -> str:
                    return ".".join([x, y] if y else [x])

                # Progressively collect all the names of intermediate
                # modules. For example, if we have the target
                # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
                # `foo.bar.baz` to the list.
                used.extend(itertools.accumulate(fullpath, join_fn))

                # For a `call_module` node, also register all recursive submodules
                # as used
                if node.op == "call_module":
                    try:
                        submod = self.get_submodule(node.target)

                        for submod_name, _ in submod.named_modules():
                            if submod_name != "":
                                used.append(".".join([node.target, submod_name]))
                    except AttributeError:
                        # Node referenced nonexistent submodule, don't need to
                        # worry about GCing anything
                        pass

        to_delete = [name for name, _ in self.named_modules() if name not in used]

        for name in to_delete:
            self.delete_submodule(name)
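
    # A typical cleanup sketch (the "dropout" target is illustrative): after
    # removing a call_module node from the graph, the corresponding submodule
    # becomes unused and can be garbage-collected here.
    #
    #     for node in gm.graph.nodes:
    #         if node.op == "call_module" and node.target == "dropout":
    #             node.replace_all_uses_with(node.args[0])
    #             gm.graph.erase_node(node)
    #     gm.recompile()
    #     gm.delete_all_unused_submodules()   # drops gm.dropout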
    @property
    def code(self) -> str:
        """
        Return the Python code generated from the ``Graph`` underlying this
        ``GraphModule``.
        """
        if not hasattr(self, "_code"):
            raise RuntimeError(
                "Code has not been generated! Please report a bug to PyTorch"
            )
        return self._code

    @compatibility(is_backward_compatible=True)
    def recompile(self) -> PythonCode:
        """
        Recompile this GraphModule from its ``graph`` attribute. This should be
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.
        """
        if isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec
        python_code = self._graph.python_code(root_module="self")
        self._code = python_code.src
        self._lineno_map = python_code._lineno_map

        cls = type(self)
        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
        cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)

        # Determine whether this class explicitly defines a __call__ implementation
        # to wrap. If it does, save it in order to have wrapped_call invoke it.
        # If it does not, wrapped_call can use a dynamic call to super() instead.
        # In most cases, super().__call__ should be torch.nn.Module.__call__.
        # We do not want to hold a reference to Module.__call__ here; doing so will
        # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
        cls_call = cls.__call__ if "__call__" in vars(cls) else None

        if "_wrapped_call" not in vars(cls):
            cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]

        def call_wrapped(self, *args, **kwargs):
            return self._wrapped_call(self, *args, **kwargs)

        cls.__call__ = call_wrapped  # type: ignore[method-assign]

        return python_code
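
    # A minimal sketch of why recompile() matters (the graph edit below is
    # illustrative): mutating the Graph in place does not touch the compiled
    # forward until recompile() regenerates it.
    #
    #     for node in gm.graph.nodes:
    #         if node.op == "call_function" and node.target == torch.add:
    #             node.target = torch.mul   # `gm.code`/`forward` are now stale...
    #     gm.recompile()                    # ...until we regenerate them here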
    # Passing Tracer as argument allows subclasses extending fx.GraphModule
    # to define their own Tracer (extending fx.Tracer).
    def __reduce_deploy__(self, importer: Importer):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
        del dict_without_graph["_graph"]

        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, importer)
        return (reduce_deploy_graph_module, (dict_without_graph, import_block))

    def __reduce_package__(self, exporter: PackageExporter):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
        del dict_without_graph["_graph"]

        generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, exporter.importer)
        module_code = import_block + self.code
        exporter.save_source_string(generated_module_name, module_code)
        return (
            reduce_package_graph_module,
            (dict_without_graph, generated_module_name),
        )

    def __reduce__(self):
        """
        Serialization of GraphModule. We serialize only the generated code, not
        the underlying ``Graph``. This is because ``Graph`` does not have on-disk
        backward-compatibility guarantees, whereas Python source code does.
        On the deserialization side, we symbolically trace through the generated
        code to regenerate the underlying ``Graph``
        """
        dict_without_graph = self.__dict__.copy()

        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, sys_importer)
        del dict_without_graph["_graph"]
        return (reduce_graph_module, (dict_without_graph, import_block))
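
    # A round-trip sketch: pickling goes through __reduce__ above, so the
    # Graph is dropped and re-traced from the generated source on load.
    #
    #     import pickle
    #     restored = pickle.loads(pickle.dumps(gm))
    #     print(restored.code)   # forward() regenerated by re-tracing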
    def _deepcopy_init(self):
        return GraphModule.__init__

    # because __reduce__ is defined for serialization,
    # we need to define deepcopy otherwise it will call __reduce__
    # and cause symbolic tracing to occur every time we try to copy the object
    def __deepcopy__(self, memo):
        res = type(self).__new__(type(self))
        memo[id(self)] = res
        fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
        self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"])

        # hooks are lost during `GraphModule.__init__`, so we need to copy them
        # over explicitly. Note that right now we only copy state_dict-related
        # hooks; to reduce BC-related issues, we can copy forward/backward-related
        # hooks in the future as well if needed.
        extra_preserved_attrs = [
            "_state_dict_hooks",
            "_load_state_dict_pre_hooks",
            "_load_state_dict_post_hooks",
            "_replace_hook",
            "_create_node_hooks",
            "_erase_node_hooks",
        ]
        for attr in extra_preserved_attrs:
            if attr in self.__dict__:
                setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
        res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
        if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
            for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
                setattr(res, attr_name, attr)
        return res
    def __copy__(self):
        from ._lazy_graph_module import _make_graph_module

        res = _make_graph_module(self, self.graph)
        res.meta = getattr(self, "meta", {})
        return res

    @compatibility(is_backward_compatible=False)
    def print_readable(self, print_output=True, include_stride=False, include_device=False):
        """
        Return the Python code generated for the current GraphModule and its children GraphModules
        """
        verbose_python_code = self._graph.python_code(
            root_module="self",
            verbose=True,
            include_stride=include_stride,
            include_device=include_device,
        )
        module_code = verbose_python_code.src
        module_code = module_code.lstrip("\n")
        module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
        module_code = _addindent(module_code, 4)

        submodule_code_list = [""]
        for submodule in self.children():
            if isinstance(submodule, GraphModule):
                submodule_code_list.append(submodule.print_readable(print_output=False))
        submodule_code = "\n".join(submodule_code_list)
        submodule_code = _addindent(submodule_code, 4)

        output = module_code + submodule_code
        if print_output:
            print(output)
        return output
    def __str__(self) -> str:
        orig_str = super().__str__()
        print_readable_reminder = (
            "# To see more debug info, please use `graph_module.print_readable()`"
        )
        return "\n".join([orig_str, self._code, print_readable_reminder])

    def _replicate_for_data_parallel(self):
        new_gm = self.__copy__()
        new_gm._is_replica = True
        return new_gm
    @contextlib.contextmanager
    def _set_replace_hook(self, f):
        """
        Takes a callable which will be called every time we replace a node
        with a new node, or change the node's name. The callable takes three
        arguments: the old node being replaced, the NAME of the new node, and
        the user node which consumes the old node being replaced.
        """
        assert callable(f), "Replace hook must be a callable."
        prev, self._replace_hook = self._replace_hook, f
        try:
            yield
        finally:
            self._replace_hook = prev
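
    # A usage sketch (the hook body is illustrative); the hook only stays
    # installed for the duration of the `with` block.
    #
    #     def log_replacement(old_node, new_name, user):
    #         print(f"{user.name}: {old_node.name} -> {new_name}")
    #
    #     with gm._set_replace_hook(log_replacement):
    #         ...  # graph edits that replace nodes will trigger the hook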
    def _register_create_node_hook(self, f):
        """
        Takes a callable which will be called after we create a new node. The
        callable takes the newly created node as input and returns None.
        """
        assert callable(f), "create_node hook must be a callable."
        self._create_node_hooks.append(f)

    def _unregister_create_node_hook(self, f):
        """
        Takes a callable which was previously registered to be called after we create a node.
        This function will unregister that callable so it is no longer invoked on node creation.
        """
        assert callable(f), "create_node hook must be a callable."
        self._create_node_hooks.remove(f)

    def _register_erase_node_hook(self, f):
        """
        Takes a callable which will be called after we erase a node. The
        callable takes the node that is being erased as input and returns None.
        """
        assert callable(f), "erase_node hook must be a callable."
        self._erase_node_hooks.append(f)

    def _unregister_erase_node_hook(self, f):
        """
        Takes a callable which was previously registered to be called after we erase a node.
        This function will unregister that callable so it is no longer invoked on node erasure.
        """
        assert callable(f), "erase_node hook must be a callable."
        self._erase_node_hooks.remove(f)

# workarounds for issues in __torch_function__

# WAR for __torch_function__ not handling tensor lists,
# fix is in https://github.com/pytorch/pytorch/pull/34725
# orig_cat = torch.cat
# def patched_cat(*args, **kwargs):
#     tensors = args[0]
#     for t in tensors:
#         if isinstance(t, Proxy):
#             return t.__torch_function__(patched_cat, (), args, kwargs)
#     return orig_cat(*args, **kwargs)
# patched_cat.__module__ = 'torch'
# patched_cat.__name__ = 'cat'
# torch.cat = patched_cat