# mypy: ignore-errors

import functools
import inspect
import itertools
import types
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List

import torch.nn

from .. import trace_rules, variables
from ..exc import unimplemented, UnspecializeRestartAnalysis, Unsupported
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import GenerationTracker
from ..source import (
    AttrSource,
    FSDPNNModuleSource,
    GetItemSource,
    NNModuleSource,
    NotNNModuleSource,
)
from ..utils import (
    get_custom_getattr,
    get_fake_value,
    is_lazy_module,
    is_namedtuple,
    is_safe_constant,
    istensor,
    istype,
    nnmodule_has_hooks,
    object_has_getattribute,
    proxy_args_kwargs,
    set_example_value,
)
from .base import MutableLocal, typestr, VariableTracker
from .functions import invoke_and_store_as_constant
from .lists import SliceVariable
from .user_defined import UserDefinedObjectVariable


def initialize_lazy_module(tx, mod, args, kwargs):
    """
    Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable.

    Causes a lazy module to be initialized (and its init hook deleted) before
    tracing. This is especially useful now that 'allowed' modules graph-break on
    hooks: calling this first ensures there is no hook left by the time we trace
    __call__, and thus no graph-break for lazy allowed modules.
    """
    if hasattr(mod, "_initialize_hook"):

        def convert_to_fake(x):
            if is_namedtuple(x):
                return type(x)(*(convert_to_fake(elem) for elem in x))
            elif isinstance(x, dict):
                return {k: convert_to_fake(v) for k, v in x.items()}
            elif isinstance(x, (list, tuple, set)):
                return type(x)(convert_to_fake(elem) for elem in x)
            elif isinstance(x, torch.fx.Proxy):
                return get_fake_value(x.node, tx)
            else:
                return x

        proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs)
        fake_args = [convert_to_fake(arg) for arg in proxy_args]
        fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()}
        mod._infer_parameters(mod, fake_args, fake_kwargs)
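
# For example (illustrative only; not exercised by this file):
#
#     lazy = torch.nn.LazyLinear(out_features=4)  # parameters uninitialized
#     lazy(torch.randn(2, 3))                     # first call infers in_features=3
#     assert type(lazy) is torch.nn.Linear        # cls_to_become swapped the class
#
# initialize_lazy_module performs the moral equivalent of that first call with
# fake values, so the initialization hook is gone before Dynamo traces __call__.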


def cleanup_source_for_nn_module_stack(source):
    # TODO(anijain2305, export-team) This is a bad hack to fix the nn module
    # fully_qualified_name to work with export/unflatten. It converts
    # mod._modules['net1'] to mod.net1.

    # This type of source occurs when we use the UnspecializedNNModule variable,
    # because an unspecialized nn module variable inlines module __getattr__
    # calls. For export, we rely heavily on NNModuleVariable and do not support
    # UnspecializedNNModule. But there is one case where this gets exposed -
    # Pippy. Pippy uses export/unflatten (an export feature) and also
    # monkeypatches the `forward` method of a mod, which forces Dynamo to use
    # UnspecializedNNModule. Therefore, we will need proper work to retain the
    # nn module stack when we let export rely on the UnspecializedNNModule
    # variable.

    # This does not work if we have recursively nested UnspecializedNNModule
    # variables, e.g. mod._modules['net1']._modules['net2']. This is unlikely
    # to happen in Pippy, so the hotfix is enough for Pippy.
    if (
        isinstance(source, GetItemSource)
        and isinstance(source.base, AttrSource)
        and isinstance(source.base.base, NNModuleSource)
        and source.base.member == "_modules"
    ):
        return AttrSource(source.base.base, source.index)
    return source
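
# Illustrative shape of the rewrite above (variable names hypothetical):
#
#     GetItemSource(AttrSource(NNModuleSource(s), "_modules"), "net1")
#         --> AttrSource(NNModuleSource(s), "net1")
#
# i.e. mod._modules['net1'] is reported in the nn module stack as mod.net1.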


@contextmanager
def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
    source_for_nn_module_stack = cleanup_source_for_nn_module_stack(source)
    fully_qualified_name = source_for_nn_module_stack.name()
    try:
        tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__)
        yield
    finally:
        del tx.nn_module_stack[module_key]
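
# An entry recorded above has, schematically (key and name are illustrative,
# not exact):
#
#     tx.nn_module_stack["L__self___layer1"] = ("L['self'].layer1", torch.nn.Linear)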


def guard_to_detect_forward_monkeypatching(source, mod):
    # Users sometimes patch the forward method of an nn module instance to
    # perform optimizations like quantization. Though this is not good
    # software practice, Python allows it, and Dynamo needs to detect such
    # patching.
    #
    # One way to do this is to add an ID_MATCH guard on every function
    # getting inlined (https://github.com/pytorch/pytorch/pull/124975). But
    # this increased guard overhead by around 20%.
    #
    # To keep the guard overhead down, we just guard on `forward` being
    # absent from the mod __dict__. The common case of patching the forward
    # method adds `forward` to the instance __dict__, whereas the unpatched
    # `forward` sits in type(mod).__dict__.
    if source:
        if "forward" in mod.__dict__ and callable(mod.__dict__["forward"]):
            # Monkeypatched forward method - add a CLOSURE_MATCH guard on the
            # forward function
            fwd = mod.__dict__["forward"]
            forward_source = AttrSource(source, "forward")
            if type(fwd) is types.MethodType:
                forward_source = AttrSource(forward_source, "__func__")
            install_guard(forward_source.make_guard(GuardBuilder.CLOSURE_MATCH))
        else:
            # Common case - check that the forward key is absent from the
            # mod __dict__
            install_guard(
                source.make_guard(
                    functools.partial(
                        GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr="forward"
                    )
                )
            )
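
# The monkeypatching pattern the guard above detects (illustrative only;
# MyModule and quantized_forward are hypothetical):
#
#     mod = MyModule()
#     mod.forward = types.MethodType(quantized_forward, mod)
#
# The assignment puts 'forward' into mod.__dict__, which flips the
# NOT_PRESENT_IN_GENERIC_DICT guard and triggers recompilation; the patched
# function itself is then guarded with CLOSURE_MATCH.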


class NNModuleVariable(VariableTracker):
    _nonvar_fields = {
        "module_type",
        "module_key",
        "module",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self, module_type: type, module_key: str, module: torch.nn.Module, **kwargs
    ):
        super().__init__(**kwargs)
        self.module_type = module_type
        self.module_key = module_key
        self.module = module
        assert self.source

    def python_type(self):
        return self.module_type

    def _wrap_submodule(self, tx, source, submod, *key_extra, **options):
        return

    def unpack_var_sequence(self, tx):
        # implement list/iter/tuple/etc calls
        base = tx.output.get_submodule(self.module_key)
        if isinstance(base, torch.nn.ModuleDict):
            result = []
            for name, submod in base.items():
                name_var = variables.ConstantVariable.create(name)
                tx.output.register_attr_or_module(
                    submod,
                    self.module_key,
                    name,
                    source=NNModuleSource(GetItemSource(self.source, name)),
                )
                result.append(name_var)
            return result

        assert isinstance(
            base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential)
        ), typestr(base)
        assert self.source
        result = []
        for idx, submod in enumerate(base):
            result.append(
                tx.output.register_attr_or_module(
                    submod,
                    self.module_key,
                    idx,
                    source=NNModuleSource(GetItemSource(self.source, idx)),
                )
            )
        return result

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        mod = tx.output.get_submodule(self.module_key)
        result = hasattr(mod, name)
        install_guard(
            NNModuleSource(AttrSource(self.source, name)).make_guard(
                GuardBuilder.HASATTR
            )
        )
        return variables.ConstantVariable.create(result)

    def is_training(self, tx):
        mod = tx.output.get_submodule(self.module_key)
        return getattr(mod, "training", False)

    def convert_to_unspecialized(self, tx):
        """Restart analysis treating this module as an UnspecializedNNModuleVariable"""
        mod = tx.output.get_submodule(self.module_key)
        GenerationTracker.tag(mod)

        # Mark the class dynamic unless we are currently in module initialization
        if tx.f_code.co_name != "__init__":
            GenerationTracker.mark_class_dynamic(type(mod))
        raise UnspecializeRestartAnalysis

    def has_key_in_generic_dict(self, tx, key):
        base = tx.output.get_submodule(self.module_key)

        if object_has_getattribute(base):
            unimplemented("NNModuleVariable with custom __getattribute__")

        if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
            mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
            return not isinstance(mutated_attr, variables.DeletedVariable)

        base_dict = object.__getattribute__(base, "__dict__")
        return key in base_dict

    def _custom_getattr_fallback(self, base, tx, name, options):
        """Check for a __getattr__ and handle it specially if it is implemented"""
        if object_has_getattribute(base):
            unimplemented("torch.nn.Module with a custom __getattribute__ defined")

        getattr_fn = get_custom_getattr(base)
        if getattr_fn is None:
            return None

        if not isinstance(getattr_fn, types.FunctionType):
            unimplemented("torch.nn.Module with a non-function custom __getattr__")

        return variables.UserMethodVariable(getattr_fn, self, **options).call_function(
            tx, [variables.ConstantVariable.create(name)], {}
        )

    def var_getattr(self, tx, name):
        from .builder import VariableBuilder

        if self.source:
            source = AttrSource(self.source, name)
        else:
            source = None

        base = tx.output.get_submodule(self.module_key)
        base_dict = object.__getattribute__(base, "__dict__")
        object_member = True
        all_class_attribute_names = set()
        for x in inspect.getmro(base.__class__):
            all_class_attribute_names.update(x.__dict__.keys())

        if not self.source:
            unimplemented("GETATTR with no source")

        if name == "__dict__":
            return variables.GetAttrVariable(self, name, source=source)

        if name in base_dict:
            subobj = base_dict[name]
        elif (
            "_modules" in base_dict
            and name in base_dict["_modules"]
            and name not in all_class_attribute_names
        ):
            subobj = base_dict["_modules"][name]
        elif "_parameters" in base_dict and name in base_dict["_parameters"]:
            subobj = base_dict["_parameters"][name]
        elif "_buffers" in base_dict and name in base_dict["_buffers"]:
            subobj = base_dict["_buffers"][name]
        else:
            try:
                subobj = inspect.getattr_static(base, name)
                object_member = False
            except AttributeError:
                # see if we can fall back to __getattr__, which is not checked
                # by getattr_static
                result = self._custom_getattr_fallback(
                    base=base, tx=tx, name=name, options={"source": source}
                )
                if result is not None:
                    return result
                # if we can't find a __getattr__, just raise the AttributeError
                raise

        if name == "forward":
            guard_to_detect_forward_monkeypatching(self.source, base)

        if name == "__class__" and not object_member:
            return variables.UserDefinedClassVariable(base.__class__, source=source)

        if object_member:
            return VariableBuilder(tx, NNModuleSource(source))(subobj)
        else:
            if istype(subobj, property):
                if self.source:
                    # Read the class attribute to reach the property
                    source = AttrSource(AttrSource(self.source, "__class__"), name)
                    # Get the getter function
                    source = AttrSource(source, "fget")
                return variables.UserFunctionVariable(
                    subobj.fget,
                    source=source,
                ).call_function(tx, [self], {})
            elif istype(subobj, classmethod):
                return variables.UserMethodVariable(
                    subobj.__func__,
                    variables.UserDefinedObjectVariable(type(base)),
                    source=source,
                )
            elif istype(subobj, staticmethod):
                return variables.UserFunctionVariable(
                    subobj.__get__(base), source=source
                )
            elif istype(subobj, types.FunctionType):
                return variables.UserMethodVariable(subobj, self, source=source)
            elif is_safe_constant(subobj) or istensor(subobj):
                # Support possibly common cases of class members
                return VariableBuilder(tx, NNModuleSource(source))(subobj)
            else:
                unimplemented(
                    f"class property {name} - {typestr(base)} {typestr(subobj)}"
                )

        return variables.GetAttrVariable(self, name, source=source)

    def call_function(
        self,
        tx,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        mod = tx.output.get_submodule(self.module_key)

        with record_nn_module_stack(self.module_key, self.source, tx, mod):
            is_lazy = is_lazy_module(mod)
            if (
                isinstance(mod, torch.nn.Sequential)
                and mod.__class__.forward is torch.nn.Sequential.forward
            ):
                if nnmodule_has_hooks(mod):
                    # We do not want to unroll a Sequential that has hooks,
                    # since unrolling it would cause the hooks to not fire!
                    # This terminates and restarts the tracing process.
                    self.convert_to_unspecialized(tx)

                # Unroll the Sequential
                assert not is_lazy, "A lazy Sequential is not a valid combination"
                assert not kwargs
                (arg,) = args
                # TODO: Use named_children when it supports remove_duplicate=False.
                for child_name, submod in mod._modules.items():
                    tx.call_function(
                        tx.output.register_attr_or_module(
                            submod,
                            self.module_key,
                            child_name,
                            source=NNModuleSource(AttrSource(self.source, child_name)),
                        ),
                        [arg],
                        {},
                    )
                    arg = tx.pop()
                return arg

            if is_lazy:
                # The module type will change after it is called
                if mod.cls_to_become is not None:
                    self.module_type = mod.cls_to_become

                # The pre-hook runs to initialize the module shapes, then
                # deletes itself. After this, the module is more or less not
                # lazy and can be treated as a normal module regardless of
                # is_allowed or other variations.
                initialize_lazy_module(tx, mod, args, kwargs)

            # If we are tracing the higher order op, we want Dynamo to step
            # inside the module call so that Dynamo can see the underlying
            # parameters and buffers and raise them as inputs to the graph.
            #
            # NB: torch.nn.utils.parametrize changes the class type of a
            # parametrized module such that its __module__ points to
            # "torch.nn.utils.parametrize".
            if (
                tx.output.is_root_tracer()
                and mod.__module__.startswith(("torch.nn.", "torch.ao."))
                and mod.__module__ != "torch.nn.utils.parametrize"
            ):
                if nnmodule_has_hooks(
                    mod, check_forward_hooks=True, check_backward_hooks=True
                ):
                    # End of fn, this bubbles up and restarts tracing.
                    self.convert_to_unspecialized(tx)

                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_module",
                        self.module_key,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
            else:
                assert self.source, (
                    "Must provide a valid source in order to inline, "
                    "since the inlined function may have default args which "
                    "must be guarded."
                )
                if isinstance(mod, torch.fx.GraphModule):
                    # TODO: do we want to support __call__ for GMs?
                    # If so, at least some changes are needed: we currently do
                    # not allow inlining the call_wrapped, and there may be
                    # other issues too
                    fn = mod.forward
                    fn_source = AttrSource(self.source, "forward")
                else:
                    fn = mod._call_impl
                    fn_source = AttrSource(self.source, "_call_impl")
                if istype(fn, types.MethodType):
                    fn = fn.__func__
                    fn_source = AttrSource(fn_source, "__func__")
                    args = [self] + args
                else:
                    assert istype(fn, types.FunctionType)
                return tx.inline_user_function_return(
                    variables.UserFunctionVariable(fn, source=fn_source),
                    args,
                    kwargs,
                )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
        constant=False,
    ) -> "VariableTracker":
        from . import ConstantVariable, ListIteratorVariable, TupleVariable

        key = self.module_key
        module = tx.output.get_submodule(key)

        def generic_call_method_helper(name):
            # Helper function to put a `call_method` node in the FX graph,
            # with nn.Module as the first arg.
            mod_proxy = tx.output.create_proxy(
                "get_attr",
                self.module_key,
                tuple(),
                {},
            )
            set_example_value(mod_proxy.node, module)

            proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs)

            from .builder import wrap_fx_proxy

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method",
                    name,
                    args=(mod_proxy, *proxy_args),
                    kwargs=proxy_kwargs,
                ),
            )
- if name in ["_call_impl", "_wrapped_call_impl"]:
- # Example: `self.layer.__call__(x)`
- # This is used for explicit calling `__call__` in a forward function.
- # Dynamo inlines `__call__`, includes hooks.
- return self.call_function(tx, args, kwargs)
- elif name == "forward":
- # Example: `self.layer.forward(x)`
- # This is used for explicit calling `forward` in a forward function.
- # Dynamo puts `call_method` node in FX, doesn't trigger hooks.
- with record_nn_module_stack(self.module_key, self.source, tx, module):
- return generic_call_method_helper(name)
- if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed(
- inspect.getfile(module.__class__._check_input_dim)
- ):
- return ConstantVariable.create(True)
- if name == "_get_item_by_idx":
- assert args[1].is_python_constant()
- assert isinstance(args[0], TupleVariable)
- mod_var = args[0].items[args[1].value]
- if isinstance(mod_var, UnspecializedNNModuleVariable):
- return mod_var
- key = mod_var.module_key
- submod = tx.output.get_submodule(key)
- return tx.output.register_attr_or_module(
- submod,
- key,
- key,
- source=NNModuleSource(GetItemSource(self.source, key)),
- )
- if constant:
- fn = getattr(module, name)
- name = f"{module.__class__.__name__}_{name}_result"
- return invoke_and_store_as_constant(tx, fn, name, args, kwargs)

        def assert_all_args_kwargs_const():
            if not all(
                x.is_python_constant() for x in itertools.chain(args, kwargs.values())
            ):
                unimplemented(f"non-const NNModule method {name}")

        def get_kwargs(*names):
            assert_all_args_kwargs_const()
            fn = getattr(module, name)
            bound_args = inspect.signature(fn).bind(
                *([x.as_python_constant() for x in args]),
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            bound_args.apply_defaults()
            bound_args = bound_args.arguments
            return {k: bound_args[k] for k in names}

        def wrap_values(items):
            result = []
            for name, submod in items:
                result.append(
                    tx.output.register_attr_or_module(
                        submod,
                        key,
                        name,
                        source=NNModuleSource(gen_source(self.source, name)),
                    )
                )
            return ListIteratorVariable(result, mutable_local=MutableLocal())

        def named_embed(name, obj):
            return TupleVariable(
                [
                    ConstantVariable.create(name),
                    tx.output.register_attr_or_module(
                        obj,
                        key,
                        name,
                        source=NNModuleSource(gen_source(self.source, name)),
                    ),
                ]
            )

        def gen_source(source, name):
            name_split = name.split(".")
            if name_split[0] == "":
                return source
            while len(name_split) > 0:
                x = name_split.pop(0)
                source = AttrSource(source, x)
            return source
- if name == "named_children":
- tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
- assert not (args or kwargs)
- result = []
- for name, submod in module.named_children():
- result.append(named_embed(name, submod))
- return ListIteratorVariable(result, mutable_local=MutableLocal())
- elif name == "named_parameters":
- tx.output.guard_on_key_order.add(
- AttrSource(self.source, "_parameters").name()
- )
- result = []
- for name, param in module.named_parameters(
- **get_kwargs("prefix", "recurse")
- ):
- result.append(named_embed(name, param))
- return ListIteratorVariable(result, mutable_local=MutableLocal())
- elif name == "named_buffers":
- tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers").name())
- result = []
- for name, buffer in module.named_buffers(
- **get_kwargs("prefix", "recurse", "remove_duplicate")
- ):
- result.append(named_embed(name, buffer))
- return ListIteratorVariable(result, mutable_local=MutableLocal())
- elif name == "named_modules":
- tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
- result = []
- for name, submod in module.named_modules(
- **get_kwargs("memo", "prefix", "remove_duplicate")
- ):
- result.append(named_embed(name, submod))
- return ListIteratorVariable(result, mutable_local=MutableLocal())
- elif name == "children":
- tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
- assert not (args or kwargs)
- return wrap_values(module.named_children())
- elif name == "modules":
- tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
- return wrap_values(module.named_modules())
- elif name == "parameters":
- tx.output.guard_on_key_order.add(
- AttrSource(self.source, "_parameters").name()
- )
- return wrap_values(module.named_parameters(**get_kwargs("recurse")))
- elif name == "buffers":
- tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers").name())
- return wrap_values(module.named_buffers(**get_kwargs("recurse")))
- elif name == "keys":
- assert not (args or kwargs)
- result = []
- for name in module.keys():
- result.append(ConstantVariable.create(name))
- return ListIteratorVariable(result, mutable_local=MutableLocal())
- elif name == "values":
- assert not (args or kwargs)
- return wrap_values(module.items())
- elif name == "items":
- assert not (args or kwargs)
- result = []
- for name, submod in module.items():
- result.append(named_embed(name, submod))
- return ListIteratorVariable(result, mutable_local=MutableLocal())
- elif name == "__len__":
- assert not (args or kwargs)
- return ConstantVariable.create(len(module))
- elif (
- name == "__contains__"
- and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict))
- and args
- and args[0].is_python_constant()
- ):
- return ConstantVariable.create(
- args[0].as_python_constant() in module._modules
- )
- elif name == "__getitem__":
- assert not kwargs and len(args) == 1
- builtin_supported = (
- torch.nn.ModuleDict.__getitem__,
- torch.nn.ModuleList.__getitem__,
- torch.nn.ParameterDict.__getitem__,
- torch.nn.ParameterList.__getitem__,
- torch.nn.Sequential.__getitem__,
- )
- if type(module).__getitem__ not in builtin_supported:
- assert isinstance(args[0], variables.ConstantVariable), typestr(args[0])
- key = args[0].as_python_constant()
- assert isinstance(key, (str, int))
- fn = getattr(module, name).__func__
- assert isinstance(fn, types.FunctionType)
- src = AttrSource(AttrSource(self.source, name), "__func__")
- return tx.inline_user_function_return(
- variables.UserFunctionVariable(fn, source=src),
- [self] + list(args),
- kwargs,
- )
- assert self.source
- if isinstance(args[0], SliceVariable):
- # Build a TupleVariable of NNModules
- result = []
- submods = []
- # Turn the slice into the list of integers
- keys = list(range(len(module)))[args[0].as_python_constant()]
- for idx, submod in enumerate(module[args[0].as_python_constant()]):
- key = keys[idx]
- src = NNModuleSource(GetItemSource(self.source, key))
- result.append(
- tx.output.register_attr_or_module(
- submod,
- key,
- source=src,
- )
- )
- submods.append(submod)
- new_module = torch.nn.Sequential(*submods)
- new_module_variable = tx.output.register_attr_or_module(
- new_module,
- f"{self}.__getitem__(slice)",
- source=NNModuleSource(
- GetItemSource(self.source, args[0].as_python_constant())
- ),
- )
- return new_module_variable
- from .tensor import SymNodeVariable
- if isinstance(args[0], SymNodeVariable):
- key = args[0].evaluate_expr(tx.output)
- else:
- key = args[0].as_python_constant()
- submod = module[key]
- return tx.output.register_attr_or_module(
- submod,
- self.module_key,
- key,
- source=NNModuleSource(GetItemSource(self.source, key)),
- )
- elif (
- name == "_get_abs_string_index"
- or (
- isinstance(module, torch.nn.modules.conv._ConvNd)
- and name == "_conv_forward"
- )
- or (
- isinstance(module, torch.nn.modules.conv._ConvTransposeNd)
- and name == "_output_padding"
- )
- ):
- # Inline the function
- fn = getattr(module, name).__func__
- fn_source = AttrSource(AttrSource(self.source, name), "__func__")
- return tx.inline_user_function_return(
- variables.UserFunctionVariable(fn, source=fn_source),
- [self] + args,
- kwargs,
- )
- # A loose heuristic, but seems to be generally good before we drop into the
- # manual handling of inputs
- elif (
- name in module.__class__.__dict__
- and callable(module.__class__.__dict__[name])
- and all(
- isinstance(x, variables.TensorVariable)
- for x in itertools.chain(args, kwargs.values())
- )
- ):
- return generic_call_method_helper(name)
- else:
- return super().call_method(tx, name, args, kwargs)


class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
    _nonvar_fields = {
        "value_type",
        "is_state_mutated",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    """
    NNModuleVariable above specializes on the id() of a module and places its
    parameters on the torch.fx.GraphModule, giving one graph per module
    instance. This version treats nn.Modules like other user-defined objects
    and passes parameters into the FX graph as inputs, giving one graph per
    module class.
    """

    def __init__(self, value, **kwargs):
        if type(value) is torch.jit._script.RecursiveScriptModule:
            raise Unsupported(
                "ScriptModules aren't supported in UnspecializedNNModuleVariable"
                " because their .forward function isn't a static member of their type"
            )
        if "value_type" in kwargs:
            lazy_value_to_become = getattr(kwargs["value_type"], "cls_to_become", None)
            if type(value) is lazy_value_to_become:
                # We may have cloned a VariableTracker for a LazyModule earlier
                # (e.g. while tracking side effects) and then later called and
                # mutated the LazyModule into a MaterializedModule. We do not
                # do the mutation upon first seeing a LazyModule since we
                # preserve eager semantics to only mutate upon first call, but
                # this requires us to update multiple copies of the
                # VariableTracker post-mutation.
                kwargs["value_type"] = type(value)

        super().__init__(value=value, **kwargs)
        self.is_state_mutated = False

    @staticmethod
    @functools.lru_cache(None)
    def _nn_module_method_ids():
        # Allow __setattr__ to fall through to the base class handler
        supported = {torch.nn.Module.__setattr__}
        return {
            id(x.__code__)
            for x in torch.nn.Module.__dict__.values()
            if hasattr(x, "__code__") and x not in supported
        }
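
    # For example (illustrative): a traced call like mod.cuda() or mod.apply(fn)
    # resolves via getattr_static to a torch.nn.Module method whose __code__ id
    # is in this set, so call_method below graph-breaks instead of inlining a
    # method Dynamo does not model for unspecialized modules.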

    def unpack_var_sequence(self, tx):
        from .builder import VariableBuilder

        try:
            fn = inspect.getattr_static(self.value_type, "__iter__")
        except AttributeError as e:
            raise NotImplementedError from e

        if fn in (
            torch.nn.ModuleList.__iter__,
            torch.nn.ParameterList.__iter__,
            torch.nn.Sequential.__iter__,
        ):
            assert self.source
            return [
                VariableBuilder(tx, source=GetItemSource(self.source, idx))(item)
                for idx, item in enumerate(self.value)
            ]

        return super().unpack_var_sequence(tx)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        mod = self.value
        # see the comment on lazy module handling in
        # NNModuleVariable.call_function for context
        if is_lazy_module(mod):
            if mod.cls_to_become is not None:
                self.value_type = mod.cls_to_become
            initialize_lazy_module(tx, mod, args, kwargs)
        name = "_call_impl"
        fn = getattr(self.value_type, name)
        if self.source:
            source = AttrSource(AttrSource(self.source, "__class__"), name)
        else:
            source = None

        guard_to_detect_forward_monkeypatching(self.source, mod)

        ctx = (
            record_nn_module_stack(str(id(mod)), self.source, tx, mod)
            if self.source
            else nullcontext()
        )
        with ctx:
            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, [self] + list(args), kwargs
            )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .builder import VariableBuilder

        if name in ["_call_impl", "_wrapped_call_impl"]:
            fn = getattr(self.value_type, name)
            if self.source:
                source = AttrSource(AttrSource(self.source, "__class__"), name)
            else:
                source = None

            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, [self] + list(args), kwargs
            )

        if name not in getattr(self.value, "__dict__", {}):
            try:
                method = inspect.getattr_static(type(self.value), name)
            except AttributeError:
                method = None

            if method is torch.nn.Module.parameters:
                assert not (args or kwargs)
                if tx.output.side_effects.has_pending_mutation(self):
                    unimplemented("Module.parameters() with pending mutation")
                install_guard(
                    self.source.make_guard(GuardBuilder.NN_MODULE_PARAM_NAMES)
                )
                items = []
                for name, value in self.value.named_parameters():
                    items.append(
                        VariableBuilder(tx, AttrSource(self.source, name))(value)
                    )
                return variables.ListIteratorVariable(
                    items, mutable_local=MutableLocal()
                )
            elif isinstance(method, staticmethod):
                source = AttrSource(
                    AttrSource(AttrSource(self.source, "__class__"), name), "__func__"
                )
                return tx.inline_user_function_return(
                    variables.UserFunctionVariable(method.__func__, source=source),
                    args,
                    kwargs,
                )

            if (
                hasattr(method, "__code__")
                and id(method.__code__) in self._nn_module_method_ids()
            ):
                unimplemented(f"UnspecializedNNModuleVariable missing {name}")

        # "_parameters" in self.value.__dict__ checks that the module is initialized
        if name == "__setattr__" and "_parameters" in self.value.__dict__:
            # Record if mutation happens on parameters/buffers/modules.
            # Mutations on these are not tracked by the base class
            # UserDefinedObject vt. This will be used later to graph break on
            # seeing parameters() and the family of related calls.
            # TODO(anijain2305) - This might not be needed if we let Dynamo
            # inline both getattr and setattr. In that case, it should see
            # the lowest-level dicts - _parameters and family - and
            # automatically track mutations on those. Investigate if that
            # can be done.
            attr_name = args[0].as_python_constant()
            value = args[1]

            # This is reverse-engineered by looking at the nn module
            # __setattr__ logic.
            if (
                isinstance(value, variables.TensorVariable)
                and value.python_type() is torch.nn.Parameter
            ) or attr_name in self.value.__dict__["_parameters"]:
                # Handle parameters
                self.is_state_mutated = True
            elif attr_name in self.value.__dict__["_buffers"]:
                # Handle buffers
                self.is_state_mutated = True
            elif (
                isinstance(
                    value,
                    (
                        variables.NNModuleVariable,
                        variables.UnspecializedNNModuleVariable,
                    ),
                )
                or attr_name in self.value.__dict__["_modules"]
            ):
                # Handle submodules
                self.is_state_mutated = True

        return super().call_method(tx, name, args, kwargs)


class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
    """
    Tracing behavior: trace into submodules and treat them as Unspecialized;
    do not register parameters to the top-level; treat them as function inputs.

    Guards behavior: if 'skip_fsdp_guards', many guards that would be installed
    by a vanilla UnspecializedNNModuleVariable are simply dropped, on the basis
    that a user wrapping their model in FSDP(model) is already opting into a
    requirement not to modify internal model state, which would already break
    FSDP without compilation.
    """

    def __init__(self, value, **kwargs):
        source = kwargs.get("source", None)
        assert (
            source is not None
        ), "FSDPManagedNNModule depends on having an accurate source to control guarding."

        super().__init__(value=value, **kwargs)
        self.source = source

    @staticmethod
    def _wrap_source(source):
        if not isinstance(source, (FSDPNNModuleSource, NotNNModuleSource)):
            if torch._dynamo.config.skip_fsdp_guards:
                return FSDPNNModuleSource(source)
            else:
                # this makes us behave like a usual UnspecializedNNModuleVariable
                # for guarding purposes
                return NotNNModuleSource(source)
        else:
            return source

    def __setattr__(self, name: str, value: Any) -> None:
        if name == "source":
            value = FSDPManagedNNModuleVariable._wrap_source(value)

        return super().__setattr__(name, value)
|