| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920 |
- # mypy: allow-untyped-defs
- import contextlib
- import copy
- import itertools
- import linecache
- import os
- import sys
- import traceback
- import warnings
- from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
- import torch
- import torch.nn as nn
- import torch.overrides
- from torch.nn.modules.module import _addindent
- from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
- from ._compatibility import compatibility
- from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode
# Public names exported by this module.
__all__ = [
    "reduce_graph_module",
    "reduce_package_graph_module",
    "reduce_deploy_graph_module",
    "GraphModule",
]

# Key inside ``GraphModule.meta`` whose value maps attribute names to values
# that are re-installed on the copy made by ``GraphModule.__deepcopy__``.
_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
# A plain ``exec`` discards the source text of the code it runs. By registering
# that text with the ``linecache`` module under a synthetic filename we keep it
# recoverable, so tools such as TorchScript can still obtain source info.
class _EvalCacheLoader:
    """Loader-protocol object that hands ``exec``-ed source text to ``linecache``."""

    def __init__(self):
        # Generated key (doubles as the dummy filename) -> source text.
        self.eval_cache = {}
        # Counter used to mint the next unique key.
        self.next_id = 0

    def cache(self, src: str, globals: Dict[str, Any], co_fields=None):
        """Remember ``src`` privately and register a lazy ``linecache`` entry for it.

        Args:
            src (str): The module source to cache
            globals (dict): The module globals; only a copy is handed to
                ``linecache``, the dict passed in is left untouched
            co_fields: Optional mapping with ``co_filename``,
                ``co_firstlineno`` and ``co_name`` entries. When truthy, they
                are appended to the generated key.

        Returns:
            str: The cache key (and dummy filename) generated for src.
        """
        filename = self._get_key()
        if co_fields:
            filename += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
        self.eval_cache[filename] = src

        # Work on a copy: this loader should only ever be visible to linecache,
        # never to other code that inspects the real globals' ``__loader__``.
        lookup_globals = globals.copy()
        lookup_globals.update(
            {"__file__": filename, "__name__": filename, "__loader__": self}
        )
        linecache.lazycache(filename, lookup_globals)
        return filename

    # Part of the loader protocol (PEP 302); linecache calls this to fetch the
    # source text belonging to a previously registered key.
    def get_source(self, module_name) -> Optional[str]:
        """Return the cached source for ``module_name``, or ``None`` if unknown."""
        return self.eval_cache.get(module_name)

    def _get_key(self):
        """Mint the next ``<eval_with_key>.N`` key and advance the counter."""
        current = self.next_id
        self.next_id = current + 1
        return f"<eval_with_key>.{current}"


# Module-wide loader instance shared by every ``_exec_with_source`` call.
_loader = _EvalCacheLoader()
def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
    """Execute ``src`` inside ``globals`` while keeping its source retrievable.

    The source is first registered with the module-level ``_loader``; the key
    it returns is used as the filename of the compiled code object.
    """
    filename = _loader.cache(src, globals, co_fields)
    code_obj = compile(src, filename, "exec")
    exec(code_obj, globals)
def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None):
    """Execute ``src`` and return the function it defines under the name ``forward``."""
    return _method_from_src("forward", src, globals, co_fields)
def _method_from_src(
    method_name: str, src: str, globals: Dict[str, Any], co_fields=None
) -> Callable:
    """Execute ``src`` and return the object it binds to ``method_name``.

    Execution happens in a copy of ``globals`` so the caller's dict is never
    mutated. Raises ``KeyError`` if ``src`` does not define ``method_name``.
    """
    scratch = globals.copy()
    _exec_with_source(src, scratch, co_fields)
    # Take the function out of the scratch namespace before handing it back.
    return scratch.pop(method_name)
def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
    """Return the import statement that makes ``obj`` available as ``name``.

    Names registered in ``_custom_builtins`` use their stored import string,
    names accepted by ``_is_from_torch`` map to ``import torch``, and anything
    else is resolved through ``importer.get_name``.
    """
    if name in _custom_builtins:
        return _custom_builtins[name].import_str
    if _is_from_torch(name):
        return "import torch"
    owner, attr = importer.get_name(obj)
    return f"from {owner} import {attr} as {name}"
def _format_import_block(globals: Dict[str, Any], importer: Importer):
    """Build a de-duplicated, sorted, newline-joined block of import statements.

    One statement is generated per entry of ``globals``.
    """
    statements: Set[str] = set()
    for name, obj in globals.items():
        statements.add(_format_import_statement(name, obj, importer))
    # Sorting keeps the block stable, which lets callers hash the graph module
    # and obtain a consistent key for use in a cache.
    ordered = sorted(statements)
    return "\n".join(ordered)
@compatibility(is_backward_compatible=True)
def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
    """Rebuild a GraphModule from its attribute dict and generated import block."""
    # BC: the attribute used to be called `code`; it was renamed to `_code` so
    # that `code` could become a documented property.
    fn_src = body.get("_code")
    if not fn_src:
        fn_src = body["code"]
    full_src = import_block + fn_src
    forward = _forward_from_src(full_src, {})
    return _deserialize_graph_module(forward, body)
@compatibility(is_backward_compatible=True)
def reduce_package_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
) -> torch.nn.Module:
    """Rebuild a GraphModule whose ``forward`` lives in a packaged generated module.

    ``forward`` is read from the module that ``importer`` loads under
    ``generated_module_name``.
    """
    generated_module = importer.import_module(generated_module_name)
    return _deserialize_graph_module(generated_module.forward, body)
@compatibility(is_backward_compatible=True)
def reduce_deploy_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], import_block: str
) -> torch.nn.Module:
    """Rebuild a GraphModule, executing its source with the importer's builtins.

    The generated code is run in a namespace whose ``__builtins__`` is
    ``importer.patched_builtins``. ``body`` must contain ``_code``.
    """
    ns = {"__builtins__": importer.patched_builtins}
    fn_src = body.get("_code")
    assert fn_src is not None
    full_src = import_block + fn_src
    forward = _forward_from_src(full_src, ns)
    return _deserialize_graph_module(forward, body)
# symbolic_trace reads forward() from the class rather than from the instance,
# so a dedicated dummy class is needed to carry it. The class is used by
# _deserialize_graph_module() below.
class _CodeOnlyModule(torch.nn.Module):
    """Bare ``nn.Module`` whose instance ``__dict__`` is replaced by ``body``."""

    def __init__(self, body):
        super().__init__()
        # Adopt the given dictionary wholesale as this instance's attributes.
        self.__dict__ = body
def _deserialize_graph_module(forward, body: Dict[Any, Any], graph_module_cls=None) -> torch.nn.Module:
    """
    Rebuild a GraphModule from the attribute dictionary of the original module.

    The in-memory graph is removed before the dictionary is saved, so that
    changes to the graph format never get serialized. Here the graph is
    recovered by symbolically tracing ``forward``.

    Args:
        forward: The generated ``forward`` function to trace.
        body: Attribute dictionary of the original module.
        graph_module_cls: Forwarded unchanged to ``_make_graph_module``.
    """
    # The tracer looks ``forward`` up on the class, so install it there.
    _CodeOnlyModule.forward = forward

    # Typed as ``Any`` to work around a mypy limitation when a class held in a
    # variable is used as a base class - https://github.com/python/mypy/issues/5865.
    tracer_base: Any = body.get("_tracer_cls")
    if tracer_base is None:
        from ._symbolic_trace import Tracer

        tracer_base = Tracer

    graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule")

    class KeepModules(tracer_base):
        # Submodules were not traced into when the original GraphModule was
        # built, so none of them may be traced into now either.
        def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
            return True

    code_only = _CodeOnlyModule(body)

    extras = body.get("_tracer_extras", {})
    graph = KeepModules().trace(code_only, **extras)

    # Record the real Tracer class on the rebuilt Graph so that it does not
    # reference the private local subclass KeepModules.
    graph._tracer_cls = tracer_base

    from ._lazy_graph_module import _make_graph_module

    gm = _make_graph_module(
        code_only,
        graph,
        class_name=graphmodule_cls_name,
        graph_module_cls=graph_module_cls,
    )

    # The GraphModule constructor keeps only attributes the graph refers to.
    # The result should be as close as possible to the module that was saved,
    # so every extra entry of ``body`` that is still missing is restored.
    for attr_name, attr_value in body.items():
        if hasattr(gm, attr_name):
            continue
        setattr(gm, attr_name, attr_value)
    return gm
def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
    """Copy the attribute with qualified name ``target`` from ``from_module`` to ``to_module``.

    Empty ``torch.nn.Module`` instances are installed on the destination for
    every intermediate path component that does not exist there yet.
    """
    *parents, leaf = target.split(".")
    src, dst = from_module, to_module
    for name in parents:
        src_child = getattr(src, name)
        dst_child = getattr(dst, name, None)
        if src_child is dst_child:
            # One of the parents has already been installed as a whole
            # (e.g. target = root.linear.weight while root.linear is already
            # shared), so everything beneath it is present on the destination.
            return
        if dst_child is None:
            dst_child = torch.nn.Module()
            setattr(dst, name, dst_child)
        src, dst = src_child, dst_child

    value = getattr(src, leaf)
    # A tensor that is not a Parameter is treated as a named buffer and is
    # registered as such on the destination module.
    is_buffer = isinstance(value, torch.Tensor) and not isinstance(
        value, torch.nn.Parameter
    )
    if is_buffer:
        dst.register_buffer(leaf, value)
    else:
        setattr(dst, leaf, value)
def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
    """Assign ``from_obj`` to the qualified name ``target`` on ``to_module``.

    Empty ``torch.nn.Module`` instances are installed for every intermediate
    path component of ``target`` that does not exist yet.
    """
    *parents, leaf = target.split(".")
    owner = to_module
    for name in parents:
        child = getattr(owner, name, None)
        if child is None:
            child = torch.nn.Module()
            setattr(owner, name, child)
        owner = child

    # A tensor that is not a Parameter is treated as a named buffer and is
    # registered as such on the owning module.
    if not isinstance(from_obj, torch.Tensor) or isinstance(
        from_obj, torch.nn.Parameter
    ):
        setattr(owner, leaf, from_obj)
    else:
        owner.register_buffer(leaf, from_obj)
class _WrappedCall:
    """Call wrapper that prints extra context for errors raised in generated code.

    Args:
        cls: The class whose instances are being called; used for the
            ``super(cls, obj).__call__`` fallback.
        cls_call: An explicit ``__call__`` implementation to invoke, or
            ``None`` to dispatch to ``super(cls, obj).__call__``.
    """

    def __init__(self, cls, cls_call):
        self.cls = cls
        self.cls_call = cls_call

    @staticmethod
    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
        """Build an informative message for an error raised in generated code.

        Previously, if an error occurred when valid symbolically-traced code
        was run with an invalid input, the user would only see the error as
        coming from ``File "<eval_with_key_N>"``. The message built here
        contains the traceback itself, a note that the error occurred in a
        traced Module's generated forward function, and a few lines of source
        context around the faulty line.
        """
        # auxiliary variables (for readability)
        err_lineno = frame_summary.lineno
        assert err_lineno is not None
        line = frame_summary.line
        assert line is not None
        err_line_len = len(line)
        all_src_lines = linecache.getlines(frame_summary.filename)

        # constituent substrings of the error message
        tb_repr = torch._dynamo.disable(traceback.format_exc)()
        custom_msg = (
            "Call using an FX-traced Module, "
            f"line {err_lineno} of the traced Module's "
            "generated forward function:"
        )
        # ``err_lineno`` is 1-based while ``all_src_lines`` is 0-based, so this
        # slice is the line before the error plus the faulty line itself. The
        # start is clamped to 0: for an error on line 1 a negative start would
        # wrap around to the end of the list and drop the faulty line.
        context_start = max(err_lineno - 2, 0)
        before_err = "".join(all_src_lines[context_start:err_lineno])
        marker = "~" * err_line_len + "~~~ <--- HERE"
        # Up to two source lines following the faulty line.
        after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])

        # joined message
        return "\n".join([tb_repr, custom_msg, before_err, marker, after_err])

    def __call__(self, obj, *args, **kwargs):
        """Invoke the wrapped call on ``obj`` and re-raise any exception.

        When the innermost frame of the traceback belongs to generated code
        (its filename contains ``eval_with_key``), the detailed message is
        printed to stderr and the exception is re-raised without its
        traceback. Any other exception is re-raised as is.
        """
        try:
            if self.cls_call is not None:
                return self.cls_call(obj, *args, **kwargs)
            else:
                return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
        except Exception as e:
            assert e.__traceback__
            topmost_framesummary: traceback.FrameSummary = (
                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]
            )  # type: ignore[arg-type]
            if "eval_with_key" in topmost_framesummary.filename:
                print(
                    _WrappedCall._generate_error_message(topmost_framesummary),
                    file=sys.stderr,
                )
                raise e.with_traceback(None)  # noqa: B904
            else:
                raise e
@compatibility(is_backward_compatible=True)
class GraphModule(torch.nn.Module):
    """
    GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a
    ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
    from that ``graph``.

    .. warning::

        When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
        regenerated. However, if you edit the contents of the ``graph`` without reassigning
        the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
        code.
    """

    def __new__(cls: "Type[GraphModule]", *args, **kwargs):
        # each instance of a graph module needs its own forward method
        # so create a new singleton class for each instance.
        # it is a subclass of the user-defined class, the only difference
        # is an extra layer to install the forward method

        # address issue described at https://github.com/pytorch/pytorch/issues/63883
        # in other words, traverse class hierarchy to fix the redundant class definition problem
        for t in cls.__mro__:
            c = t.__qualname__.split(".")[-1]
            if c != "GraphModuleImpl":
                cls = t
                break

        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
            pass

        return super().__new__(GraphModuleImpl)

    @compatibility(is_backward_compatible=True)
    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        class_name: str = "GraphModule",
    ):
        """
        Construct a GraphModule.

        Args:

            root (Union[torch.nn.Module, Dict[str, Any]]):
                ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
                In the case that ``root`` is a Module, any references to Module-based objects (via qualified
                name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
                within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
                In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
                looked up directly in the dict's keys. The object mapped to by the Dict will be copied
                over into the appropriate place within the GraphModule's module hierarchy.

            graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation

            class_name (str): ``class_name`` denotes the name of this GraphModule for debugging purposes.
                If it's unset, all error messages will report as originating from ``GraphModule``. It may
                be helpful to set this to ``root``'s original name or a name that makes sense within the
                context of your transform.

        Raises:
            RuntimeError: If ``root`` is a dict that lacks a target referenced by
                the graph, or if ``root`` is neither a Module nor a dict.
        """
        super().__init__()
        self.__class__.__name__ = class_name
        if isinstance(root, torch.nn.Module):
            if hasattr(root, "training"):
                self.training = root.training

            # When we pickle/unpickle graph module, we don't want to drop any module or attributes.
            if isinstance(root, _CodeOnlyModule):
                for k, _ in root.named_children():
                    _copy_attr(root, self, k)

                for k, _ in root.named_buffers():
                    _copy_attr(root, self, k)

                for k, _ in root.named_parameters():
                    _copy_attr(root, self, k)

            for node in graph.nodes:
                if node.op in ["get_attr", "call_module"]:
                    assert isinstance(node.target, str)
                    _copy_attr(root, self, node.target)
        elif isinstance(root, dict):
            targets_to_copy = []
            for node in graph.nodes:
                if node.op in ["get_attr", "call_module"]:
                    assert isinstance(node.target, str)
                    if node.target not in root:
                        raise RuntimeError(
                            "Node "
                            + str(node)
                            + " referenced target "
                            + node.target
                            + " but that target was not provided in ``root``!"
                        )
                    targets_to_copy.append(node.target)
            # Sort targets in ascending order of the # of atoms.
            # This will ensure that less deeply nested attributes are assigned
            # before more deeply nested attributes. For example, foo.bar
            # will be assigned before foo.bar.baz. Otherwise, we might assign
            # the user-provided ``foo.bar`` and wipe out the previously-assigned
            # ``foo.bar.baz``
            targets_to_copy.sort(key=lambda t: t.count("."))
            for target_to_copy in targets_to_copy:
                _assign_attr(root[target_to_copy], self, target_to_copy)
        else:
            raise RuntimeError("Unsupported type " + str(root) + " passed for root!")

        self.graph = graph

        # Store the Tracer class responsible for creating a Graph separately as part of the
        # GraphModule state, except when the Tracer is defined in a local namespace.
        # Locally defined Tracers are not pickleable. This is needed because torch.package will
        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
        # to re-create the Graph during deserialization.
        self._tracer_cls = None
        if (
            self.graph._tracer_cls
            and "<locals>" not in self.graph._tracer_cls.__qualname__
        ):
            self._tracer_cls = self.graph._tracer_cls

        self._tracer_extras = {}
        if self.graph._tracer_extras:
            self._tracer_extras = self.graph._tracer_extras

        # Dictionary to store metadata
        self.meta: Dict[str, Any] = {}
        self._replace_hook = None
        self._create_node_hooks: List[Callable] = []
        self._erase_node_hooks: List[Callable] = []

    # TorchScript breaks trying to compile the graph setter because of the
    # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
    #
    # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
    __jit_unused_properties__ = ["graph"]

    @property
    def graph(self) -> Graph:
        """
        Return the ``Graph`` underlying this ``GraphModule``
        """
        return self._graph

    @graph.setter
    def graph(self, g: Graph) -> None:
        """
        Set the underlying ``Graph`` for this ``GraphModule``. This will internally
        recompile the ``GraphModule`` so that the generated ``forward()`` function
        corresponds to ``g``
        """
        assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}"
        self._graph = g
        g.owning_module = self
        self.recompile()

    @compatibility(is_backward_compatible=False)
    def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
        """Dumps out module to ``folder`` with ``module_name`` so that it can be
        imported with ``from <folder> import <module_name>``

        Writes ``state_dict.pt``, ``module.py`` and ``__init__.py`` into
        ``folder``. Child modules whose type is not in a small list of types
        with a safe ``repr`` are pickled to ``<child name>.pt`` instead, and a
        warning listing them is issued.

        Args:

            folder (Union[str, os.PathLike]): The folder to write the code out to

            module_name (str): Top-level name to use for the ``Module`` while
                writing out the code
        """
        folder = Path(folder)
        # ``parents=True`` lets a nested output location be created in one go.
        folder.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), folder / "state_dict.pt")
        tab = " " * 4
        custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()])

        # Pieces of the generated ``module.py``; joined once at the end. The
        # header is spelled out line by line so that its indentation does not
        # depend on how this source file is laid out.
        chunks: List[str] = [
            "\n"
            "import torch\n"
            f"{custom_builtins}\n"
            "from torch.nn import *\n"
            f"class {module_name}(torch.nn.Module):\n"
            f"{tab}def __init__(self):\n"
            f"{tab * 2}super().__init__()\n"
        ]

        # Only these exact types are emitted via their ``repr``; the check is
        # on the exact type, so subclasses are pickled instead.
        safe_reprs = (
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
        )

        def _gen_model_repr(module: torch.nn.Module) -> Optional[str]:
            # Source expression for ``module``, or None if it must be pickled.
            if type(module) in safe_reprs:
                return f"{module.__repr__()}"
            return None

        blobified_modules = []
        for child_name, child in self.named_children():
            module_str = _gen_model_repr(child)
            if module_str is None:
                module_file = folder / f"{child_name}.pt"
                torch.save(child, module_file)
                blobified_modules.append(child_name)
                module_repr = child.__repr__().replace("\r", " ").replace("\n", " ")
                # NOTE(review): the generated code unpickles a whole module via
                # ``torch.load``; only load folders produced by a trusted source.
                module_str = f"torch.load(r'{module_file}') # {module_repr}"
            chunks.append(f"{tab*2}self.{child_name} = {module_str}\n")

        for buffer_name, buffer in self._buffers.items():
            if buffer is None:
                continue
            chunks.append(
                f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
            )

        for param_name, param in self._parameters.items():
            if param is None:
                continue
            chunks.append(
                f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
            )

        chunks.append(
            f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
        )
        chunks.append(f"{_addindent(self.code, 4)}\n")
        model_str = "".join(chunks)

        module_file = folder / "module.py"
        module_file.write_text(model_str)

        init_file = folder / "__init__.py"
        init_file.write_text("from .module import *")

        if len(blobified_modules) > 0:
            warnings.warn(
                "Was not able to save the following children modules as reprs - "
                f"saved as pickled files instead: {blobified_modules}"
            )

    @compatibility(is_backward_compatible=True)
    def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
        """
        Adds the given submodule to ``self``.

        This installs empty Modules where none exist yet if they are
        subpaths of ``target``.

        Args:
            target: The fully-qualified string name of the new submodule
                (See example in ``nn.Module.get_submodule`` for how to
                specify a fully-qualified string.)
            m: The submodule itself; the actual object we want to
                install in the current Module

        Return:
            bool: Whether or not the submodule could be inserted. For
                this method to return True, each object in the chain
                denoted by ``target`` must either a) not exist yet,
                or b) reference an ``nn.Module`` (not a parameter or
                other attribute)
        """
        *prefix, field = target.split(".")
        mod: torch.nn.Module = self

        for item in prefix:
            submod = getattr(mod, item, None)

            if submod is None:
                submod = torch.nn.Module()
                setattr(mod, item, submod)

            if not isinstance(submod, torch.nn.Module):
                return False

            mod = submod

        mod.add_module(field, m)
        return True

    @compatibility(is_backward_compatible=True)
    def delete_submodule(self, target: str) -> bool:
        """
        Deletes the given submodule from ``self``.

        The module will not be deleted if ``target`` is not a valid
        target.

        Args:
            target: The fully-qualified string name of the submodule to
                delete (See example in ``nn.Module.get_submodule`` for
                how to specify a fully-qualified string.)

        Returns:
            bool: Whether or not the target string referenced a
                submodule we want to delete. A return value of ``False``
                means that the ``target`` was not a valid reference to
                a submodule.
        """
        atoms = target.split(".")
        path, target_submod = atoms[:-1], atoms[-1]
        mod: torch.nn.Module = self

        # Get the parent module
        for item in path:
            if not hasattr(mod, item):
                return False

            mod = getattr(mod, item)

            if not isinstance(mod, torch.nn.Module):
                return False

        if not hasattr(mod, target_submod):
            return False

        if not isinstance(getattr(mod, target_submod), torch.nn.Module):
            return False

        delattr(mod, target_submod)
        return True

    @compatibility(is_backward_compatible=True)
    def delete_all_unused_submodules(self) -> None:
        """
        Deletes all unused submodules from ``self``.

        A Module is considered "used" if any one of the following is
        true:
        1. It has children that are used
        2. Its forward is called directly via a ``call_module`` node
        3. It has a non-Module attribute that is used from a
        ``get_attr`` node

        This method can be called to clean up an ``nn.Module`` without
        manually calling ``delete_submodule`` on each unused submodule.
        """

        # If we're looking at multiple parts of a path, join them with a
        # dot. Otherwise, return that single element without doing
        # anything to it.
        def join_fn(x: str, y: str) -> str:
            return ".".join([x, y] if y else [x])

        # A set gives constant-time membership tests when filtering below.
        used: Set[str] = set()

        for node in self.graph.nodes:
            if node.op not in ("call_module", "get_attr"):
                continue

            # A list of strings representing the different parts
            # of the path. For example, `foo.bar.baz` gives us
            # ["foo", "bar", "baz"]
            fullpath = node.target.split(".")

            # Progressively collect all the names of intermediate
            # modules. For example, if we have the target
            # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
            # `foo.bar.baz` to the set.
            used.update(itertools.accumulate(fullpath, join_fn))

            # For a `call_module` node, also register all recursive submodules
            # as used
            if node.op == "call_module":
                try:
                    submod = self.get_submodule(node.target)

                    for submod_name, _ in submod.named_modules():
                        if submod_name != "":
                            used.add(".".join([node.target, submod_name]))
                except AttributeError:
                    # Node referenced nonexistent submodule, don't need to
                    # worry about GCing anything
                    pass

        # Collect the names first: deleting while ``named_modules`` is being
        # iterated would modify the hierarchy under the iterator.
        to_delete = [name for name, _ in self.named_modules() if name not in used]

        for name in to_delete:
            self.delete_submodule(name)

    @property
    def code(self) -> str:
        """
        Return the Python code generated from the ``Graph`` underlying this
        ``GraphModule``.
        """
        if not hasattr(self, "_code"):
            raise RuntimeError(
                "Code has not been generated! Please report a bug to PyTorch"
            )
        return self._code

    @compatibility(is_backward_compatible=True)
    def recompile(self) -> PythonCode:
        """
        Recompile this GraphModule from its ``graph`` attribute. This should be
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.
        """
        if isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec
        python_code = self._graph.python_code(root_module="self")
        self._code = python_code.src
        self._lineno_map = python_code._lineno_map

        cls = type(self)
        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
        cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)

        # Determine whether this class explicitly defines a __call__ implementation
        # to wrap. If it does, save it in order to have wrapped_call invoke it.
        # If it does not, wrapped_call can use a dynamic call to super() instead.
        # In most cases, super().__call__ should be torch.nn.Module.__call__.
        # We do not want to hold a reference to Module.__call__ here; doing so will
        # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
        cls_call = cls.__call__ if "__call__" in vars(cls) else None

        if "_wrapped_call" not in vars(cls):
            cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]

        def call_wrapped(self, *args, **kwargs):
            return self._wrapped_call(self, *args, **kwargs)

        cls.__call__ = call_wrapped  # type: ignore[method-assign]

        return python_code

    # Passing Tracer as argument allows subclasses extending fx.GraphModule
    # define their own Tracer (extending fx.Tracer).
    def __reduce_deploy__(self, importer: Importer):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
        del dict_without_graph["_graph"]

        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, importer)
        return (reduce_deploy_graph_module, (dict_without_graph, import_block))

    def __reduce_package__(self, exporter: PackageExporter):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
        del dict_without_graph["_graph"]

        generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, exporter.importer)
        module_code = import_block + self.code
        exporter.save_source_string(generated_module_name, module_code)
        return (
            reduce_package_graph_module,
            (dict_without_graph, generated_module_name),
        )

    def __reduce__(self):
        """
        Serialization of GraphModule. We serialize only the generated code, not
        the underlying ``Graph``. This is because ``Graph`` does not have on-disk
        backward-compatibility guarantees, whereas Python source code does.
        On the deserialization side, we symbolically trace through the generated
        code to regenerate the underlying ``Graph``
        """
        dict_without_graph = self.__dict__.copy()

        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, sys_importer)
        del dict_without_graph["_graph"]
        return (reduce_graph_module, (dict_without_graph, import_block))

    def _deepcopy_init(self):
        # Initializer that ``__deepcopy__`` runs on the freshly created copy.
        return GraphModule.__init__

    # because __reduce__ is defined for serialization,
    # we need to define deepcopy otherwise it will call __reduce__
    # and cause symbolic tracing to occur every time we try to copy the object
    def __deepcopy__(self, memo):
        res = type(self).__new__(type(self))
        memo[id(self)] = res
        fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
        self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"])
        # hooks are lost during `GraphModule.__init__`, so we need to copy over
        # them explicitly, note right now we are only copying state_dict related
        # hooks, to reduce bc-related issues, we can copy forward/backward related
        # hooks in the future as well if needed
        extra_preserved_attrs = [
            "_state_dict_hooks",
            "_load_state_dict_pre_hooks",
            "_load_state_dict_post_hooks",
            "_replace_hook",
            "_create_node_hooks",
            "_erase_node_hooks",
        ]
        for attr in extra_preserved_attrs:
            if attr in self.__dict__:
                setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
        res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
        if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
            for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
                setattr(res, attr_name, attr)
        return res

    def __copy__(self):
        from ._lazy_graph_module import _make_graph_module

        res = _make_graph_module(self, self.graph)
        # The copy refers to the same ``meta`` dict object as the original.
        res.meta = getattr(self, "meta", {})
        return res

    @compatibility(is_backward_compatible=False)
    def print_readable(self, print_output=True, include_stride=False, include_device=False):
        """
        Return the Python code generated for current GraphModule and its children GraphModules

        Args:
            print_output: When true, the result is also printed.
            include_stride: Forwarded to ``Graph.python_code`` for this module
                and for every child GraphModule.
            include_device: Forwarded to ``Graph.python_code`` for this module
                and for every child GraphModule.
        """
        verbose_python_code = self._graph.python_code(
            root_module="self",
            verbose=True,
            include_stride=include_stride,
            include_device=include_device,
        )
        module_code = verbose_python_code.src
        module_code = module_code.lstrip("\n")
        module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
        module_code = _addindent(module_code, 4)

        submodule_code_list = [""]
        for submodule in self.children():
            if isinstance(submodule, GraphModule):
                # Forward the formatting flags so that nested modules are
                # rendered the same way as the root.
                submodule_code_list.append(
                    submodule.print_readable(
                        print_output=False,
                        include_stride=include_stride,
                        include_device=include_device,
                    )
                )
        submodule_code = "\n".join(submodule_code_list)
        submodule_code = _addindent(submodule_code, 4)

        output = module_code + submodule_code
        if print_output:
            print(output)
        return output

    def __str__(self) -> str:
        orig_str = super().__str__()
        print_readable_reminder = (
            "# To see more debug info, please use `graph_module.print_readable()`"
        )
        return "\n".join([orig_str, self._code, print_readable_reminder])

    def _replicate_for_data_parallel(self):
        new_gm = self.__copy__()
        new_gm._is_replica = True
        return new_gm

    @contextlib.contextmanager
    def _set_replace_hook(self, f):
        """
        Takes a callable which will be called every time we replace a node
        with a new node, or change the node's name. Callable takes three arguments:
        the old node we're changing, and NAME of the new node, followed by the
        user node which consumes the old node to be replaced.

        The previously installed hook is restored when the context exits.
        """
        assert callable(f), "Replace hook must be a callable."
        prev, self._replace_hook = self._replace_hook, f
        try:
            yield
        finally:
            self._replace_hook = prev

    def _register_create_node_hook(self, f):
        """
        Takes a callable which will be called after we create a new node. The
        callable takes the newly created node as input and returns None.
        """
        assert callable(f), "create_node hook must be a callable."
        self._create_node_hooks.append(f)

    def _unregister_create_node_hook(self, f):
        """
        Takes a callable which was previously registered to be called after we create a node.
        This function will unregister that callable so it is no longer invoked on node creation.
        """
        assert callable(f), "create_node hook must be a callable."
        self._create_node_hooks.remove(f)

    def _register_erase_node_hook(self, f):
        """
        Takes a callable which will be called after we erase a node. The
        callable takes the node that is being erased as input and returns None.
        """
        assert callable(f), "erase_node hook must be a callable."
        self._erase_node_hooks.append(f)

    def _unregister_erase_node_hook(self, f):
        """
        Takes a callable which was previously registered to be called after we erase a node.
        This function will unregister that callable so it is no longer invoked on node erasure.
        """
        assert callable(f), "erase_node hook must be a callable."
        self._erase_node_hooks.remove(f)
- # workarounds for issues in __torch_function__
- # WAR for __torch_function__ not handling tensor lists,
- # fix is in https://github.com/pytorch/pytorch/pull/34725
- # orig_cat = torch.cat
- # def patched_cat(*args, **kwargs):
- # tensors = args[0]
- # for t in tensors:
- # if isinstance(t, Proxy):
- # return t.__torch_function__(patched_cat, (), args, kwargs)
- # return orig_cat(*args, **kwargs)
- # patched_cat.__module__ = 'torch'
- # patched_cat.__name__ = 'cat'
- # torch.cat = patched_cat
|