| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073 |
- # mypy: allow-untyped-defs
- import collections
- import functools
- import inspect
- import sys
- import textwrap
- import types
- import warnings
- from typing import Dict, List, Set, Type
- import torch
- import torch._jit_internal as _jit_internal
- from torch._sources import fake_range
- from torch.jit._builtins import _find_builtin
- from torch.jit._check import AttributeTypeIsSupportedChecker
- from torch.jit._state import _add_script_class, _get_script_class, _python_cu
- from torch.jit.frontend import (
- get_class_properties,
- get_default_args,
- get_jit_class_def,
- get_jit_def,
- )
- from torch.nn import Module
# Records handed to the compiler.
# - ScriptMethodStub: resolution callback, parsed JIT AST (``def_``) and the
#   Python function the stub was made from (``original_method``; methods that
#   were added through ``define()`` have none).
# - PropertyStub: resolution callback and parsed JIT AST of a property.
ScriptMethodStub = collections.namedtuple(
    "ScriptMethodStub", ["resolution_callback", "def_", "original_method"]
)
PropertyStub = collections.namedtuple(
    "PropertyStub", ["resolution_callback", "def_"]
)
# Names skipped when infer_concrete_type_builder scans ``nn_module.__dict__``,
# so they are never turned into TorchScript attributes. PyTorch attaches these
# to Python modules.
# TODO: there should be a more principled way of doing this.
ignored_attributes = [
    "_version", "_parameters", "_buffers", "_non_persistent_buffers_set",
    "_backward_hooks", "_backward_pre_hooks",
    "_forward_hooks", "_forward_hooks_with_kwargs",
    "_forward_pre_hooks", "_forward_pre_hooks_with_kwargs",
    "_forward_hooks_always_called",
    "_state_dict_hooks", "_state_dict_pre_hooks",
    "_load_state_dict_pre_hooks", "_load_state_dict_post_hooks",
    "_modules", "_initializing", "dump_patches",
]
def _compile_and_register_class(obj, rcb, qualified_name):
    """Return the TorchScript class for Python class ``obj``, compiling it if needed.

    If ``_get_script_class`` already knows ``obj``, that cached script class is
    returned. Otherwise the class definition is parsed, compiled under
    ``qualified_name`` using the resolution callback ``rcb``, registered with
    ``_add_script_class`` and returned.
    """
    cached = _get_script_class(obj)
    if cached:
        return cached

    class_ast = get_jit_class_def(obj, obj.__name__)
    class_defaults = torch.jit.frontend.get_default_args_for_class(obj)
    compiled = torch._C._jit_script_class_compile(
        qualified_name, class_ast, class_defaults, rcb
    )
    _add_script_class(obj, compiled)
    return compiled
def make_stub(func, name):
    """Wrap ``func`` in a ScriptMethodStub whose JIT AST is named ``name``.

    The resolution callback is created from ``func``'s closure, and the AST is
    parsed with ``self`` typed as ``RecursiveScriptModule``.
    """
    return ScriptMethodStub(
        resolution_callback=_jit_internal.createResolutionCallbackFromClosure(func),
        def_=get_jit_def(func, name, self_name="RecursiveScriptModule"),
        original_method=func,
    )
def make_stub_from_method(nn_module, method_name):
    """Return a ScriptMethodStub for the attribute ``method_name`` of ``nn_module``.

    An attribute that already is a ScriptMethodStub is returned unchanged.
    """
    method = getattr(nn_module, method_name)
    if not isinstance(method, ScriptMethodStub):
        # Pass `method_name` explicitly so the resulting AST carries the name
        # that was requested, which can differ from the function's own name:
        #
        #     def _forward(self):
        #         pass
        #     forward = _forward
        #
        # Here the function object is called `_forward`, yet the stub is
        # being requested for `forward`.
        return make_stub(method, method_name)
    return method
def make_stubs_from_exported_methods(mod):
    """Build a stub for every attribute of ``mod`` whose TorchScript modifier is EXPORT.

    Attributes are visited in ``dir(mod)`` order. An attribute whose lookup
    raises AttributeError is treated as ``None`` and therefore skipped.
    """
    return [
        make_stub_from_method(mod, attr_name)
        for attr_name in dir(mod)
        if _jit_internal.get_torchscript_modifier(getattr(mod, attr_name, None))
        is _jit_internal.FunctionModifiers.EXPORT
    ]
def jit_ignored_properties(module):
    """Return the entries of ``__jit_ignored_attributes__`` that name properties.

    Only properties found in ``vars(type(module))`` count, i.e. those defined
    directly on the module's own class. Properties inherited from base classes
    are not included.
    """
    ignored_names = getattr(module, "__jit_ignored_attributes__", list())
    own_properties = {
        attr
        for attr, value in vars(type(module)).items()
        if isinstance(value, property)
    }
    return {attr for attr in ignored_names if attr in own_properties}
# Base types whose values may be TorchScript constants; tuples and lists of
# these base types are accepted as well (see _get_valid_constant).
# If you edit this list, then you also need to edit the handlers in
# ConstantValue in jit/script/init.cpp
_constant_types = (
    bool,
    float,
    int,
    str,
    type(None),
    torch.device,
    torch.layout,
    torch.dtype,
)


def _get_valid_constant(attr, v, owner_type):
    """Validate ``v`` as a constant for attribute ``owner_type.attr``.

    Returns ``v`` unchanged when it is an instance of one of ``_constant_types``.
    A tuple or list is validated element by element (recursively) and returned
    as a tuple. Any other value raises ``TypeError`` describing what is allowed.
    """
    if isinstance(v, _constant_types):
        return v
    if isinstance(v, (tuple, list)):
        return tuple(_get_valid_constant(attr, elem, owner_type) for elem in v)

    valid_names = ", ".join(torch.typename(t) for t in _constant_types)
    raise TypeError(
        textwrap.dedent(
            f"""
            '{torch.typename(type(v))}' object in attribute '{owner_type}.{attr}' is not a valid constant.
            Valid constants are:
            1. a nn.ModuleList
            2. a value of type {{{valid_names}}}
            3. a list or tuple of (2)
            """
        )
    )
class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
    """A ``SourceRangeFactory`` whose constructor forwards its four arguments unchanged."""

    def __init__(self, source, filename, file_lineno, leading_whitespace_len):
        super().__init__(
            source,
            filename,
            file_lineno,
            leading_whitespace_len,
        )
def get_annotations(obj):
    """Return the type annotations that apply to ``obj`` (a class or an instance).

    On Python < 3.10 this is ``obj.__annotations__`` (``{}`` when absent); normal
    attribute lookup already falls back to the class and its bases.

    On Python >= 3.10 ``inspect.get_annotations`` is used, as recommended by
    https://docs.python.org/3.10/howto/annotations.html. It does not inherit
    annotations from base classes, so when ``obj`` itself has none, the class
    hierarchy of ``obj`` (``obj`` itself when it is a class, else ``type(obj)``)
    is searched depth-first over ``__bases__``. The first non-empty annotation
    dict found is returned, or ``{}`` if there is none.
    """
    if sys.version_info < (3, 10):
        return getattr(obj, "__annotations__", {})

    # inspect.get_annotations only accepts classes, modules and callables; it
    # raises TypeError for anything else (e.g. a plain non-callable instance).
    # Such objects go straight to the class-based lookup below, which matches
    # the pre-3.10 behaviour of reading the annotations through the class.
    if isinstance(obj, (type, types.ModuleType)) or callable(obj):
        annotations = inspect.get_annotations(obj)
        if annotations:
            return annotations

    def get_cls_annotations(cls):
        # Depth-first search: the class itself, then each base in order.
        cls_annotations = inspect.get_annotations(cls)
        if cls_annotations:
            return cls_annotations
        for base in cls.__bases__:
            cls_annotations = get_cls_annotations(base)
            if cls_annotations:
                return cls_annotations
        return {}

    cls = obj if isinstance(obj, type) else type(obj)
    return get_cls_annotations(cls)
def infer_concrete_type_builder(nn_module, share_types=True):
    """
    Build a ConcreteModuleTypeBuilder from an nn.Module.

    This ConcreteModuleType doesn't have a JIT type associated with it yet, it
    must be filled in by the caller.

    The builder is populated, in this order, with: container flags
    (ModuleDict, ModuleList/Sequential, ParameterList, ParameterDict), the
    user-ignored attribute names, parameters, buffers, submodules, constants
    (from ``__constants__`` and ``Final`` annotations), method overloads,
    function / builtin / data attributes found in ``nn_module.__dict__`` (or a
    failure hint for those that cannot be converted), and forward hooks and
    forward pre-hooks.

    Args:
        nn_module: The Python ``nn.Module`` instance to describe.
        share_types: Forwarded to ``get_module_concrete_type`` when computing
            the concrete type of a submodule whose type cannot be inferred.

    Returns:
        The populated ``torch._C.ConcreteModuleTypeBuilder``.

    Raises:
        RuntimeError: if inferring the type of a parameter, buffer, submodule
            or attribute fails with a RuntimeError (re-raised with context).
    """
    concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
    if isinstance(nn_module, (torch.nn.ModuleDict)):
        concrete_type_builder.set_module_dict()
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
        concrete_type_builder.set_module_list()
    if isinstance(nn_module, (torch.nn.ParameterList)):
        concrete_type_builder.set_parameter_list()
    if isinstance(nn_module, (torch.nn.ParameterDict)):
        concrete_type_builder.set_parameter_dict()

    class_annotations = get_annotations(nn_module)
    # QuantWrapper's class annotations are not used for type inference.
    if isinstance(nn_module, (torch.ao.quantization.QuantWrapper)):
        class_annotations = {}

    # Get user-annotated ignored attributes.
    user_annotated_ignored_attributes = getattr(
        nn_module, "__jit_ignored_attributes__", list()
    )
    concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes)
    ignored_properties = jit_ignored_properties(nn_module)

    # try to infer the type from type annotation or from the object itself
    def infer_type(name, item):
        """Return ``(InferredType, inferred)`` for attribute ``name`` holding ``item``.

        ``inferred`` is True only when the type was derived from the value
        itself rather than from a class annotation or a ``torch.jit.Attribute``.
        """
        # The forward function from Module is special; never use its annotation;
        # we need to infer the type directly using JIT. I originally wanted to write
        # this test as isinstance(class_annotations[name], Callable) but
        # isinstance on typing things doesn't seem to work: isinstance(list, Callable)
        # is also true!
        inferred = False
        try:
            if (
                name in class_annotations
                and class_annotations[name]
                != torch.nn.Module.__annotations__["forward"]
            ):
                ann_to_type = torch.jit.annotations.ann_to_type(
                    class_annotations[name], fake_range()
                )
                attr_type = torch._C.InferredType(ann_to_type)
            elif isinstance(item, torch.jit.Attribute):
                ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            else:
                attr_type = torch._C._jit_try_infer_type(item)
                inferred = True
        except RuntimeError as re:
            raise RuntimeError(f"Error inferring type for {name}: {item}: {re}") from re

        return attr_type, inferred

    added_names = set()

    for name, item in nn_module._parameters.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        # We currently have the invariant in various places in our code
        # that parameters must be Tensors. However, the nn.Module API also
        # allows NoneType parameters. These parameters are not returned as
        # part of `parameters()` and its variants, but are available
        # through direct attribute access.
        concrete_type_builder.add_attribute(name, attr_type.type(), True, False)
        added_names.add(name)

    for name, item in nn_module._buffers.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        concrete_type_builder.add_attribute(name, attr_type.type(), False, True)
        added_names.add(name)

    for name, item in nn_module._modules.items():
        if name in user_annotated_ignored_attributes:
            continue

        attr_type, _ = infer_type(name, item)
        if item is None:
            # Modules can be None. We don't have direct support for optional
            # Modules, so register it as a NoneType attribute instead.
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
            continue
        if attr_type.success():
            assert attr_type.type().is_interface_type()
            # if the type can be inferred, it should be a module interface type
            sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(
                attr_type.type()
            )
        else:
            # otherwise we get the concrete module type for item and add it to concrete_type
            sub_concrete_type = get_module_concrete_type(item, share_types)
        concrete_type_builder.add_module(name, sub_concrete_type)

        added_names.add(name)

    # populate constants_set
    constants_set = set(getattr(nn_module, "__constants__", ()))

    # Constants annotated via `Final[T]` rather than being added to `__constants__`
    for name, ann in class_annotations.items():
        if torch._jit_internal.is_final(ann):
            constants_set.add(name)

    for name in constants_set:
        if name in added_names:
            # TODO: We should really error in this case, but its bc-breaking so
            # we need to warn for at least one release
            if name in nn_module._modules:
                hint = "submodule"
            elif name in nn_module._buffers:
                hint = "buffer"
            elif name in nn_module._parameters:
                hint = "parameter"
            else:
                raise AssertionError(
                    "added_names must be submodule, parameter, or buffer"
                )

            warnings.warn(
                f"'{name}' was found in ScriptModule constants, "
                f"but it is a non-constant {hint}. Consider removing it."
            )
            continue
        if not hasattr(nn_module, name):
            # TODO: We should really error in this case, but its bc-breaking so
            # we need to warn for at least one release
            warnings.warn(
                f"'{name}' was found in ScriptModule constants, "
                "but was not actually set in __init__. "
                "Consider removing it."
            )
            continue
        value = getattr(nn_module, name)
        concrete_type_builder.add_constant(
            name, _get_valid_constant(name, value, type(nn_module).__name__)
        )
        added_names.add(name)

    # populate overloads
    # Work on a copy: `__overloads__` may be a dict owned by the module or even
    # by its class (shared across instances), and the update below must not
    # mutate it.
    overloads = dict(getattr(nn_module, "__overloads__", {}))
    # update with any annotated overloads
    overloads.update(
        get_overload_name_mapping(
            get_overload_annotations(nn_module, ignored_properties)
        )
    )
    for name, overloaded_names in overloads.items():
        concrete_type_builder.add_overload(name, overloaded_names)

    for name, value in nn_module.__dict__.items():
        if name in ignored_attributes or name.startswith("__"):
            # Python objects have lots of random attributes attached to them;
            # PyTorch adds a few more. Prevent these from getting compiled.
            continue

        if name in user_annotated_ignored_attributes:
            continue

        if name in added_names:
            # Don't re-add anything we already added
            continue

        isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket)
        if isoverloadpacket:
            value = value.op
        # Handle Python function attributes
        if inspect.isfunction(value):
            try:
                scripted_fn = torch.jit.script(value)
                concrete_type_builder.add_function_attribute(
                    name, torch._C._jit_try_infer_type(scripted_fn).type(), value
                )
            except Exception as e:
                # If we fail to script the function, it isn't a hard error.
                # Instead, we will add it to the list of attributes we failed
                # to convert, with the compilation error.
                hint = (
                    "(This function exists as an attribute on the Python module, "
                    "but we failed to compile it to a TorchScript function. "
                    f"\nThe error stack is reproduced here:\n{e})"
                )
                concrete_type_builder.add_failed_attribute(name, hint)

            continue

        # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or
        # a call to an aten function like torch.add)
        builtin_symbol_name = _find_builtin(value)
        if builtin_symbol_name:
            concrete_type_builder.add_builtin_function(name, builtin_symbol_name)
            continue

        # Handle Script function attributes
        if isinstance(value, torch.jit.ScriptFunction):
            concrete_type_builder.add_function_attribute(
                name, torch._C._jit_try_infer_type(value).type(), value
            )
            continue

        # If we got here, this is a regular "data" attribute, add it to the concrete type
        attr_type, inferred = infer_type(name, value)
        if attr_type.success():
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
        else:
            # TODO: could add more detail here. For example, what the user should do
            # when the pytype is `list` or `NoneType`
            inferred_msg = (
                "Its type was inferred; try adding a type annotation for the attribute."
                if inferred
                else ""
            )
            additional_info = f"{attr_type.reason()}. {inferred_msg}"
            hint = (
                "(This attribute exists on the Python module, "
                f"but we failed to convert Python type: '{torch.typename(type(value))}' "
                f"to a TorchScript type. {additional_info})"
            )
            concrete_type_builder.add_failed_attribute(name, hint)

    # add hooks to concrete type
    for hook in nn_module._forward_hooks.values():
        concrete_type_builder.add_forward_hook(hook)
    for pre_hook in nn_module._forward_pre_hooks.values():
        concrete_type_builder.add_forward_pre_hook(pre_hook)

    return concrete_type_builder
class ConcreteTypeStore:
    """Cache of ConcreteModuleTypes, grouped by the Python module class they were built from."""

    # Python module type => list of ConcreteModuleTypes built for that class
    type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]]
    # ConcreteModuleTypes whose methods have already been compiled
    methods_compiled: Set[torch._C.ConcreteModuleType]

    def __init__(self):
        self.type_store = {}
        self.methods_compiled = set()

    def get_or_create_concrete_type(self, nn_module):
        """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are re-used if possible."""
        builder = infer_concrete_type_builder(nn_module)
        candidates = self.type_store.setdefault(type(nn_module), [])

        # Reuse the first stored concrete type that matches the builder.
        match = next((known for known in candidates if known.equals(builder)), None)
        if match is not None:
            return match

        # Nothing matched: build a new concrete type and remember it.
        fresh = builder.build()
        candidates.append(fresh)
        return fresh


# Module-level store shared by get_module_concrete_type and create_script_module_impl.
concrete_type_store = ConcreteTypeStore()
def create_methods_and_properties_from_stubs(
    concrete_type, method_stubs, property_stubs
):
    """Compile ``method_stubs`` and ``property_stubs`` onto ``concrete_type``.

    The stubs are unpacked into parallel lists (ASTs, resolution callbacks and,
    for methods, the default argument values of ``original_method``). These
    lists are passed to ``concrete_type._create_methods_and_properties``.
    """

    def _field(stubs, attr):
        return [getattr(stub, attr) for stub in stubs]

    method_defs = _field(method_stubs, "def_")
    method_rcbs = _field(method_stubs, "resolution_callback")
    method_defaults = [
        get_default_args(stub.original_method) for stub in method_stubs
    ]

    property_defs = _field(property_stubs, "def_")
    property_rcbs = _field(property_stubs, "resolution_callback")

    concrete_type._create_methods_and_properties(
        property_defs, property_rcbs, method_defs, method_rcbs, method_defaults
    )
def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):
    """Compile forward hooks and forward pre-hooks onto ``concrete_type``.

    Each stub contributes its AST (``def_``) and its resolution callback. The
    four resulting lists are handed to ``concrete_type._create_hooks``.
    """

    def _defs_and_rcbs(stubs):
        return (
            [stub.def_ for stub in stubs],
            [stub.resolution_callback for stub in stubs],
        )

    hook_defs, hook_rcbs = _defs_and_rcbs(hook_stubs)
    pre_hook_defs, pre_hook_rcbs = _defs_and_rcbs(pre_hook_stubs)
    concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs)
def get_module_concrete_type(nn_module, share_types=True):
    """
    Get a concrete type for nn_modules.

    A ScriptModule that already carries a ``_concrete_type`` gets that one back.
    Otherwise, if share_types is True, the concrete type is fetched from (or
    added to) concrete_type_store. If it is False, a new concrete type is
    built without searching concrete_type_store, after its builder is marked
    with ``set_poisoned()``.

    Args:
        nn_module: The original Python nn.Module that we are creating a ScriptModule for.
        share_types: Whether to share underlying JIT types between modules (if possible).

    Returns:
        A concrete type for nn_module.
    """
    assert isinstance(nn_module, Module)
    is_script_module = isinstance(nn_module, torch.jit.ScriptModule)
    if is_script_module and hasattr(nn_module, "_concrete_type"):
        return nn_module._concrete_type

    if not share_types:
        # Build a concrete type directly, bypassing the shared type store.
        builder = infer_concrete_type_builder(nn_module, share_types)
        builder.set_poisoned()
        return builder.build()

    # Reuse a cached JIT type from the store when possible.
    return concrete_type_store.get_or_create_concrete_type(nn_module)
def create_script_class(obj):
    """
    Create and return a RecursiveScriptClass instance from a Python object.

    The object's class is scripted (unless it already was), an empty
    ``torch._C.ScriptObject`` of that scripted type is created, every entry of
    ``obj.__dict__`` is copied onto it, and the result is wrapped.

    Arguments:
        obj: A Python object.
    """
    py_type = type(obj)
    qualified_class_name = _jit_internal._qualified_name(py_type)
    rcb = _jit_internal.createResolutionCallbackForClassMethods(py_type)

    # Ensure the class is compiled and registered (returns early if cached).
    _compile_and_register_class(py_type, rcb, qualified_class_name)
    class_ty = _python_cu.get_class(qualified_class_name)

    # Instantiate the scripted type and mirror the instance attributes onto it.
    cpp_object = torch._C._create_object_with_type(class_ty)
    for attr_name, attr_value in obj.__dict__.items():
        cpp_object.setattr(attr_name, attr_value)

    return wrap_cpp_class(cpp_object)
def create_script_module(nn_module, stubs_fn, share_types=True, is_tracing=False):
    """
    Create a new ScriptModule from an nn.Module.

    Args:
        nn_module: The original Python nn.Module that we are creating a ScriptModule for.
        stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
        share_types: Whether to share underlying JIT types between modules (if possible).
            NOTE: Only set to False this when we cannot guarantee type sharing will work
                correctly. This only happens today for traced modules, where the same
                module can produce different traced methods depending on the inputs.
        is_tracing: Whether this function is called during tracing or scripting. If tracing,
                we don't need to do AttributeTypeIsSupportedChecker because all the unsupported
                attributes will be baked as constant in the tracing graph. In addition,
                this check significantly slows down the traced modules when the module size is big.

    Returns:
        The RecursiveScriptModule produced by ``create_script_module_impl``.
    """
    assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
    check_module_initialized(nn_module)
    concrete_type = get_module_concrete_type(nn_module, share_types)
    # The attribute-type check only matters (and only costs time) when scripting.
    if not is_tracing:
        AttributeTypeIsSupportedChecker().check(nn_module)
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
def create_script_module_impl(nn_module, concrete_type, stubs_fn):
    """
    Convert an nn.Module to a RecursiveScriptModule.

    Outline of what happens below:
      1. A C++ module of type ``concrete_type.jit_type`` is created, and the
         method, property and hook stubs are collected from ``nn_module``.
      2. ``RecursiveScriptModule._construct`` wraps it, using ``init_fn`` to copy
         attributes, recursively script submodules and copy ignored
         methods/attributes.
      3. If this concrete type has not had its methods compiled yet, methods,
         properties and hooks are compiled, and the concrete type is recorded
         in ``concrete_type_store.methods_compiled``.
      4. Hooks, container ``__len__``/``__contains__``, compiled methods,
         properties and COPY_TO_SCRIPT_WRAPPER attributes are exposed on the
         Python wrapper.

    Args:
        nn_module: The original Python nn.Module that we are creating a ScriptModule for.
        concrete_type: The fully initialized ConcreteType of the module.
        stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.

    Returns:
        The new ``torch.jit.RecursiveScriptModule``.
    """
    cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
    method_stubs = stubs_fn(nn_module)
    property_stubs = get_property_stubs(nn_module)
    hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)

    # NOTE(review): this local is not referenced anywhere below in this function.
    user_annotated_ignored_attributes = getattr(
        nn_module, "__jit_ignored_attributes__", list()
    )
    ignored_properties = jit_ignored_properties(nn_module)

    # Passed to RecursiveScriptModule._construct below; presumably invoked once
    # while the wrapper is being initialized (confirm in torch/jit/_script.py).
    def init_fn(script_module):
        # Initialize the ScriptModule:
        # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
        for name in concrete_type.get_attributes().keys():
            orig_value = getattr(nn_module, name)
            # torch.jit.Attribute wrappers contribute their `.value`, not the wrapper.
            orig_value = (
                orig_value.value
                if isinstance(orig_value, torch.jit.Attribute)
                else orig_value
            )
            cpp_module.setattr(name, orig_value)

        # 2. Copy the submodules from the original `nn_module` to the new ScriptModule,
        #    recursively scripting them.
        for name, sub_concrete_type in concrete_type.get_modules():
            orig_value = getattr(nn_module, name)
            assert isinstance(
                orig_value, Module
            ), f"Expected Module but got {type(orig_value)}"
            module_type = sub_concrete_type.jit_type
            if isinstance(module_type, torch._C.InterfaceType):
                # use the interface inference rule to compile the module
                scripted = interface_script(module_type, orig_value)
            elif isinstance(orig_value, torch.jit.ScriptModule):
                # Already scripted: reuse as-is.
                scripted = orig_value
            else:
                # always reuse the provided stubs_fn to infer the methods to compile
                scripted = create_script_module_impl(
                    orig_value, sub_concrete_type, stubs_fn
                )

            cpp_module.setattr(name, scripted)
            script_module._modules[name] = scripted

        # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule.
        #    This ensures we can access these Python methods on the ScriptModule.
        for name in dir(nn_module):
            if name in ignored_properties:
                continue
            item = getattr(nn_module, name, None)
            if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
                # Re-bind the underlying function to the ScriptModule wrapper.
                unbound_function = getattr(nn_module, name).__func__
                bound_method = unbound_function.__get__(script_module)
                setattr(script_module, name, bound_method)
            elif concrete_type.is_ignored_attribute(name):
                setattr(script_module, name, item)

        # For convenience, attach the concrete type to the new ScriptModule
        script_module._concrete_type = concrete_type

    # Actually create the ScriptModule, initializing it with the function we just defined
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)

    # Compile methods if necessary (only once per concrete type; later modules
    # sharing this concrete type skip compilation).
    if concrete_type not in concrete_type_store.methods_compiled:
        create_methods_and_properties_from_stubs(
            concrete_type, method_stubs, property_stubs
        )
        # Create hooks after methods to ensure no name collisions between hooks and methods.
        # If done before, hooks can overshadow methods that aren't exported.
        create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
        torch._C._run_emit_module_hook(cpp_module)
        concrete_type_store.methods_compiled.add(concrete_type)

    # Copy the forward hooks and pre-hooks to the new ScriptModule
    # to allow the hooks to be run from eager as ScriptFunctions
    for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
        script_module._forward_pre_hooks[idx] = fn
    for idx, fn in enumerate(script_module._c._get_forward_hooks()):
        script_module._forward_hooks[idx] = fn

    # Special handling so methods like __len__ work in script methods on classes derived from containers
    if (
        isinstance(
            nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
        )
        and "__len__" not in cpp_module._method_names()
    ):
        # The length is fixed at scripting time and baked into the method body.
        script_module.define(f"def __len__(self):\n return {len(nn_module)}\n")
    if (
        isinstance(nn_module, torch.nn.ModuleDict)
        and "__contains__" not in cpp_module._method_names()
    ):
        # The key set is likewise fixed at scripting time.
        if len(nn_module.keys()):
            keys = repr(list(nn_module.keys()))
            script_module.define(
                f"def __contains__(self, key: str):\n return key in {keys}\n"
            )
        else:
            script_module.define("def __contains__(self, key: str):\n return False\n")

    # Make the compiled methods available to the Python ScriptModule class.
    for method_stub in method_stubs:
        if method_stub.original_method is None:
            # define()'d methods don't have an Python original_method, so we
            # don't need to do any Python re-wrapping stuff
            continue

        name = method_stub.original_method.__name__
        if name != method_stub.def_.name().name:
            # TODO: Why skip this? Because @torch.jit._overload_method will
            # mangle the name of the function.
            continue
        script_method = cpp_module._get_method(name)

        # Wrap the original to propagate docstrings and such.
        # TODO: we don't currently do this functions that are recursively
        # compiled, we should.
        wrapped_script_method = functools.wraps(method_stub.original_method)(
            script_method
        )

        # Add the methods to the script_module directly. This ensures they will
        # be found first when `name` is looked up (as opposed to the stubs or
        # nn.Module.forward)
        script_module.__dict__[name] = wrapped_script_method

    # Make module properties available on the Python ScriptModule class.
    for property_stub in property_stubs:
        property_name = property_stub.def_.name().name
        fget = cpp_module._get_method(property_stub.def_.getter_name().name)
        # Setter is optional, so it may not exist.
        setter_name = property_stub.def_.setter_name()
        fset = cpp_module._get_method(setter_name.name) if setter_name else None
        # NOTE(review): the builtin `property` takes (fget, fset, fdel, doc), so this
        # call passes `property_name` as fget, `fget` as fset and `fset` as fdel.
        # The object is stored in the *instance* __dict__, where Python does not
        # invoke it as a descriptor; confirm whether the argument order matters
        # before changing it.
        script_module.__dict__[property_name] = property(property_name, fget, fset)  # type: ignore[arg-type]

    # copy over python methods to script module if they aren't defined on the script module
    # this is currently an internal api used only on module containers
    # NOTE(review): add_python_attr_to_scripted_model copies only when
    # script_model_defines_attr(...) is True; check that this matches the comment above.
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if (
            _jit_internal.get_torchscript_modifier(item)
            is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
        ):
            add_python_attr_to_scripted_model(script_module, nn_module, name)

    return script_module
# RecursiveScriptModule defines shims for certain attributes so that magic
# methods keep working. To tell whether a script model defines an attribute,
# we therefore also have to check that what we found is not the shim.
def script_model_defines_attr(script_model, attr):
    """Return whether ``script_model.attr`` differs from ``RecursiveScriptModule.attr``.

    Returns False when the attribute is missing (or ``None``) on
    ``script_model``, or missing (or ``None``) on
    ``torch.jit.RecursiveScriptModule``. Otherwise returns the result of
    comparing the two values with ``!=``.
    """
    found = getattr(script_model, attr, None)
    if found is not None:
        shim = getattr(torch.jit.RecursiveScriptModule, attr, None)
        if shim is not None:
            return found != shim
    return False
def add_python_attr_to_scripted_model(script_model, orig, attr):
    """Copy ``orig.attr`` onto ``script_model`` if ``orig`` has it and
    ``script_model_defines_attr(script_model, attr)`` is true."""
    if not hasattr(orig, attr):
        return
    if script_model_defines_attr(script_model, attr):
        setattr(script_model, attr, getattr(orig, attr))
def get_overload_annotations(mod, jit_ignored_properties):
    """Collect the method overloads registered in ``_jit_internal`` for ``mod``.

    Every name in ``dir(type(mod))`` (except those in ``jit_ignored_properties``)
    is looked up on ``mod``. Each callable with a non-None ``__module__`` is
    passed to ``_jit_internal._get_overloaded_methods``.

    Returns:
        A dict mapping each method that has overloads to a list of
        ``(mangled_name, overload_fn)`` pairs, where ``mangled_name`` is
        ``"<name>__<i>"`` for the i-th overload.

    Raises:
        RuntimeError: if the method's own function is one of its overloads
            (message from ``get_overload_no_implementation_error_message``).
    """
    # original function => [(mangled overload name, overload function)]
    result = {}
    for attr_name in dir(type(mod)):
        if attr_name in jit_ignored_properties:
            continue
        candidate = getattr(mod, attr_name, None)
        if not callable(candidate):
            continue
        # Some builtin callables have no (or a None) __module__; skip them.
        if getattr(candidate, "__module__", None) is None:
            continue

        overload_fns = _jit_internal._get_overloaded_methods(candidate, mod.__class__)
        if overload_fns is None:
            continue

        if candidate.__func__ in overload_fns:
            raise RuntimeError(
                _jit_internal.get_overload_no_implementation_error_message(
                    "method", candidate.__func__
                )
            )

        result[candidate] = [
            (f"{attr_name}__{idx}", fn) for idx, fn in enumerate(overload_fns)
        ]
    return result
def get_overload_name_mapping(overload_info):
    """Convert ``get_overload_annotations`` output into the ``__overloads__`` format.

    Maps each original function's ``__name__`` to the list of its mangled
    overload names, in order. Functions that share a ``__name__`` share one
    list.
    """
    # Same format as __overloads__
    # original function => [overload names]
    name_to_overloads: Dict[str, List[str]] = {}
    for original_fn, pairs in overload_info.items():
        bucket = name_to_overloads.setdefault(original_fn.__name__, [])
        bucket.extend(mangled for mangled, _fn in pairs)
    return name_to_overloads
def _check_no_signature(func):
    """Raise RuntimeError when no type signature can be obtained for ``func``.

    ``torch.jit.annotations.get_signature`` returning ``None`` is treated as
    ``func`` lacking the type annotations that overloads require.
    """
    signature = torch.jit.annotations.get_signature(
        func, None, fake_range(), inspect.ismethod(func)
    )
    if signature is not None:
        return
    raise RuntimeError(
        "Must explicitly add type annotations to overloaded functions: "
        f"{_jit_internal._qualified_name(func)}"
    )
def make_stubs_for_overloads(overload_info):
    """Create one ScriptMethodStub per ``(overload_name, overload_fn)`` in ``overload_info``.

    Each overload must have a type signature (checked by ``_check_no_signature``).
    Its declaration and the original function's AST are combined by
    ``torch._C._replace_overloaded_method_decl`` under ``overload_name``. The
    resolution callback comes from the original function's closure.
    """
    stubs = []
    for orig_fn, overloads in overload_info.items():
        base_ast = get_jit_def(
            orig_fn, orig_fn.__name__, self_name="RecursiveScriptModule"
        )
        for mangled_name, overload_fn in overloads:
            _check_no_signature(overload_fn)
            decl_ast = get_jit_def(
                overload_fn, overload_fn.__name__, self_name="RecursiveScriptModule"
            )
            merged_ast = torch._C._replace_overloaded_method_decl(
                decl_ast.decl(), base_ast, mangled_name
            )
            rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn)
            stubs.append(ScriptMethodStub(rcb, merged_ast, overload_fn))
    return stubs
def check_module_initialized(mod):
    """Raise RuntimeError if ``mod`` is not ready to be scripted.

    ``mod`` must be an ``nn.Module`` that has ``_parameters`` (i.e. its
    ``super().__init__()`` ran). Unless it has a ``remote_parameters``
    attribute, none of its own parameters or buffers may be lazy.
    """
    assert isinstance(mod, torch.nn.Module)
    if not hasattr(mod, "_parameters"):
        raise RuntimeError(
            f"'{torch.typename(type(mod))}' has not been initialized, did you forget to call 'super()'?"
        )

    # This is to avoid importing torch.distributed.nn
    if hasattr(mod, "remote_parameters"):
        return

    for kind, registry_name in (("parameters", "_parameters"), ("buffers", "_buffers")):
        for name, value in getattr(mod, registry_name).items():
            if value is not None and torch.nn.parameter.is_lazy(value):
                raise RuntimeError(
                    f"'{torch.typename(type(mod))}' has uninitialized {kind} {name}. Did you forget to run a forward pass?"
                )
def infer_methods_to_compile(nn_module):
    """Implement the default rules for which methods should act as starting points for compilation.

    (TODO add a link when the rules are published).

    Starting points, in order:
      * ``forward``, when it is not ignored and its underlying function differs
        from ``torch.nn.Module.forward`` (i.e. the subclass overrides it);
      * every attribute whose TorchScript modifier is ``EXPORT``, skipping
        names listed as ignored properties.

    Overloaded methods are not compiled directly. Instead, one stub is made per
    overload, and the overload names are merged into ``nn_module.__overloads__``
    (existing entries are kept unless replaced).

    Returns:
        Overload stubs first, then one stub per remaining method. Method names
        are de-duplicated while keeping their first-seen order.
    """
    check_module_initialized(nn_module)
    ignored_properties = jit_ignored_properties(nn_module)

    methods: List[str] = []
    if hasattr(nn_module, "forward") and not _jit_internal.is_ignored_fn(
        nn_module.forward
    ):
        forward_func = getattr(nn_module.forward, "__func__", None)
        module_forward = getattr(torch.nn.Module, "forward", None)
        # Only compile forward when the subclass overrides nn.Module.forward.
        if forward_func != module_forward:
            methods.append("forward")

    for name in dir(nn_module):
        # Skip ignored properties so getattr does not evaluate their getters.
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if (
            _jit_internal.get_torchscript_modifier(item)
            is _jit_internal.FunctionModifiers.EXPORT
        ):
            methods.append(name)

    overload_name_mappings = dict(getattr(nn_module, "__overloads__", {}))
    overload_info = get_overload_annotations(nn_module, ignored_properties)
    overload_name_mappings.update(get_overload_name_mapping(overload_info))
    overload_stubs = make_stubs_for_overloads(overload_info)
    nn_module.__overloads__ = overload_name_mappings

    # Overloaded methods are compiled only through their overloads. dict.fromkeys
    # de-duplicates while preserving insertion order. A set would make the
    # compile order non-deterministic.
    uniqued_methods = dict.fromkeys(
        name for name in methods if name not in overload_name_mappings
    )
    stubs = [make_stub_from_method(nn_module, method) for method in uniqued_methods]
    return overload_stubs + stubs
def get_hook_stubs(nn_module):
    """Return forward hook and pre_hook ScriptModuleStubs.

    Hooks are identified by ``__name__``, and one name table is shared by
    forward hooks and pre-hooks.
      * The first function seen under a name gets a stub.
      * Seeing the very same function again is skipped silently.
      * Seeing a different function under an existing name raises
        ``RuntimeError``.

    Returns:
        ``(hook_stubs, pre_hook_stubs)``.
    """
    check_module_initialized(nn_module)
    seen_by_name: Dict = {}

    hook_stubs = []
    for fwd_hook in nn_module._forward_hooks.values():
        hook_name = fwd_hook.__name__
        if hook_name not in seen_by_name:
            seen_by_name[hook_name] = fwd_hook
            hook_stubs.append(make_stub(fwd_hook, hook_name))
        elif fwd_hook is not seen_by_name[hook_name]:
            raise RuntimeError(
                f"Hook '{fwd_hook.__name__}' on {type(nn_module).__name__} "
                "has at least two different python definitions."
                " Please use unique names for all hooks."
            )

    pre_hook_stubs = []
    for fwd_pre_hook in nn_module._forward_pre_hooks.values():
        hook_name = fwd_pre_hook.__name__
        if hook_name not in seen_by_name:
            seen_by_name[hook_name] = fwd_pre_hook
            pre_hook_stubs.append(make_stub(fwd_pre_hook, hook_name))
        elif fwd_pre_hook is not seen_by_name[hook_name]:
            raise RuntimeError(
                f"Pre-hook '{fwd_pre_hook.__name__}' on {type(nn_module).__name__} "
                "has at least two different python definitions."
                " Please use unique names for all hooks."
            )

    return hook_stubs, pre_hook_stubs
def get_property_stubs(nn_module):
    """Create property stubs for the properties of the module by creating method stubs for the getter and setter.

    Each stub pairs a property AST from ``get_class_properties`` with a
    resolution callback built from the closure of that property's getter.

    Raises:
        RuntimeError: if a property defined on the module's class has no getter.
    """
    module_ty = type(nn_module)

    # Validate getters and build resolution callbacks BEFORE parsing the
    # properties. get_class_properties would otherwise fail first, with a less
    # helpful error, on a property whose getter is missing.
    rcbs = {}
    for name in dir(module_ty):
        item = getattr(module_ty, name, None)
        if not isinstance(item, property):
            continue
        if not item.fget:
            # nn_module is an instance, which normally has no ``__name__``.
            # Use the class name so the intended RuntimeError is raised.
            raise RuntimeError(
                f"Property {name} of {module_ty.__name__} must have a getter"
            )
        rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget)

    properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule")
    return [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts]
def interface_script(mod_interface, nn_module):
    """
    Make a ScriptModule from an nn.Module, compiling exactly the methods declared by an interface.

    Args:
        mod_interface: the interface type the module implements. Its
            ``getMethodNames()`` lists the methods to compile.
        nn_module: The original Python nn.Module that we are creating a ScriptModule for.

    Returns:
        ``nn_module`` unchanged if it is already a ``torch.jit.ScriptModule``.
        Otherwise, the result of ``create_script_module``.
    """
    if isinstance(nn_module, torch.jit.ScriptModule):
        return nn_module

    check_module_initialized(nn_module)

    def infer_interface_methods_to_compile(nn_module):
        """Rule to infer the methods from the interface type.

        Every interface method becomes a compilation starting point.
        """
        return [
            make_stub_from_method(nn_module, method_name)
            for method_name in mod_interface.getMethodNames()
        ]

    return create_script_module(nn_module, infer_interface_methods_to_compile)
def try_compile_fn(fn, loc):
    """Script ``fn`` on demand during recursive compilation.

    Args:
        fn: the callable encountered while compiling.
        loc: accepted for interface compatibility; not used here.

    Returns:
        ``None`` if ``fn`` is ignored/unused or is an ``nn.Module``. Otherwise,
        the result of ``torch.jit.script`` on ``fn``, or on the object returned
        by its ``__prepare_scriptable__`` if it defines one.

    Raises:
        RuntimeError: if ``fn`` is neither a Python function nor a method.
    """
    # Ignored functions are left alone. Modules are callable, so pybind
    # presents them as functions, but they are not compiled here either.
    if _jit_internal.is_ignored_fn(fn) or isinstance(fn, torch.nn.Module):
        return None

    if not (inspect.isfunction(fn) or inspect.ismethod(fn)):
        raise RuntimeError(
            f"`{fn}` is not a function. Recursive scripting only supports "
            "Python functions or methods currently.\n"
            f"Consider manually annotating `{fn}` with @torch.jit.script."
        )

    # __prepare_scriptable__ may return an object with a different closure,
    # so resolve it before building the resolution callback.
    if hasattr(fn, "__prepare_scriptable__"):
        fn = fn.__prepare_scriptable__()  # type: ignore[operator]

    # The defining scope is unavailable. The closed-over variables on the
    # function object stand in for it.
    closure_rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
    return torch.jit.script(fn, _rcb=closure_rcb)
def wrap_cpp_class(cpp_class):
    """Return a Python ``RecursiveScriptClass`` wrapping the ``torch._C.Object`` *cpp_class*."""
    wrapper_type = torch.jit.RecursiveScriptClass
    return wrapper_type(cpp_class)
def wrap_cpp_module(cpp_module):
    """Wrap a ``torch._C.ScriptModule`` in a Python ``RecursiveScriptModule``.

    Submodules are wrapped recursively and attached under their names. The
    concrete type is rebuilt from the underlying JIT type. The C++ module's
    forward pre-hooks and forward hooks are copied into the Python-side hook
    dicts, keyed by their position.
    """

    def init_fn(script_module):
        submodules = torch._C.ModuleDict(script_module._c).items()
        for child_name, child_cpp in submodules:
            setattr(script_module, child_name, wrap_cpp_module(child_cpp))

        jit_type = script_module._c._type()
        script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
            jit_type
        )

        for position, pre_hook in enumerate(script_module._c._get_forward_pre_hooks()):
            script_module._forward_pre_hooks[position] = pre_hook
        for position, post_hook in enumerate(script_module._c._get_forward_hooks()):
            script_module._forward_hooks[position] = post_hook

    return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
def compile_unbound_method(concrete_type, fn):
    """Compile ``fn`` as a method of ``concrete_type`` unless it is ignored.

    Returns:
        The stub created for ``fn``, or ``None`` if ``fn`` is ignored.
    """
    if not _jit_internal.is_ignored_fn(fn):
        method_stub = make_stub(fn, fn.__name__)
        # Hook emission is suppressed because the graph that calls this
        # function is still being built.
        with torch._jit_internal._disable_emit_hooks():
            create_methods_and_properties_from_stubs(concrete_type, (method_stub,), ())
        return method_stub
    return None
def lazy_bind(concrete_type, unbound_method):
    """
    Return a function that binds `unbound_method` to a Module IValue only when it is called, then invokes it.

    The returned callable takes ``(cpp_module, *args)``. On each call it wraps
    ``cpp_module`` in a ``RecursiveScriptModule`` and binds ``unbound_method``
    to that wrapper. Deferring the binding until call time means Python-side
    tricks that would poison type sharing cannot take effect at compile time.
    The callable carries ``unbound_method``'s ``__name__``, its TorchScript
    modifier, and an ``original_fn`` back-reference.
    """

    def populate(script_module):
        source_class = concrete_type.py_class
        # Copy @ignored/@unused methods from the original class so they are
        # available during execution.
        for attr_name in dir(source_class):
            candidate = getattr(source_class, attr_name, None)
            if _jit_internal.is_ignored_fn(candidate):
                setattr(script_module, attr_name, candidate)
        # Copy constants so they are available during execution.
        for const_name, const_value in concrete_type.get_constants().items():
            setattr(script_module, const_name, const_value)

    def lazy_binding_method(cpp_module, *args):
        wrapped_module = torch.jit.RecursiveScriptModule._construct(cpp_module, populate)
        bound = types.MethodType(unbound_method, wrapped_module)
        return bound(*args)

    # Make the lazy binding method "look like" the original method.
    lazy_binding_method.original_fn = unbound_method  # type: ignore[attr-defined]
    lazy_binding_method.__name__ = unbound_method.__name__
    torch._jit_internal.copy_torchscript_modifier(unbound_method, lazy_binding_method)
    return lazy_binding_method
|