- # mypy: allow-untyped-defs
- import dataclasses
- import functools
- import inspect
- import logging
- import re
- import time
- import warnings
- from contextlib import contextmanager, nullcontext
- from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
- import torch
- import torch._dynamo
- import torch.fx
- import torch.utils._pytree as pytree
- from torch._dynamo.exc import UserError, UserErrorType
- from torch._export.non_strict_utils import (
- _fakify_script_objects,
- _gather_constant_attrs,
- make_constraints,
- make_fake_inputs,
- make_fake_params_buffers,
- produce_guards_and_solve_constraints,
- )
- from torch._export.passes._node_metadata_hook import (
- _node_metadata_hook,
- _set_node_metadata_hook,
- )
- from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
- _AddRuntimeAssertionsForInlineConstraintsPass,
- )
- from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
- from torch._export.passes.lift_constants_pass import (
- ConstantAttrMap,
- lift_constants_pass,
- rewrite_script_object_meta,
- )
- from torch._export.utils import placeholder_naming_pass, placeholder_prefixes
- from torch._export.verifier import SpecViolationError
- from torch._export.wrappers import _wrap_submodules
- from torch._functorch.aot_autograd import aot_export_module
- from torch._guards import detect_fake_mode
- from torch._library.fake_class_registry import FakeScriptObject
- from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
- from torch._utils_internal import log_export_usage
- from torch.export.dynamic_shapes import _combine_args
- from torch.export.exported_program import OutputKind
- from torch.fx._utils import first_call_function_nn_module_stack
- from torch.fx.experimental.symbolic_shapes import (
- ConstraintViolationError,
- free_unbacked_symbols,
- GuardOnDataDependentSymNode,
- ShapeEnv,
- )
- from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
- from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
- from torch.utils._pytree import TreeSpec
- from torch.utils._sympy.value_ranges import ValueRangeError
- from ._safeguard import AutogradStateOpsFailSafeguard
- from .exported_program import (
- _disable_prexisiting_fake_mode,
- ExportedProgram,
- InputKind,
- ModuleCallEntry,
- ModuleCallSignature,
- )
- from .graph_signature import (
- _sig_to_specs,
- ArgumentSpec,
- ConstantArgument,
- CustomObjArgument,
- ExportGraphSignature,
- SymIntArgument,
- TensorArgument,
- TokenArgument,
- )
- log = logging.getLogger(__name__)
- @dataclasses.dataclass
- class ExportDynamoConfig:
- """
- Manage Export-specific configurations of Dynamo.
- """
- allow_rnn: bool = True
- reorderable_logging_functions: Set[Callable] = dataclasses.field(
- default_factory=set
- )
- @dataclasses.dataclass
- class ExportedArtifact:
- gm: torch.fx.GraphModule
- sig: ExportGraphSignature
- constants: Dict[
- str,
- Union[
- torch.Tensor,
- FakeScriptObject,
- torch.ScriptObject,
- ],
- ]
- out_spec: Optional[TreeSpec] = None # type: ignore[assignment]
- fake_mode: Optional[FakeTensorMode] = None # type: ignore[assignment]
- module_call_specs: Optional[Dict[str, Dict[str, pytree.TreeSpec]]] = None # type: ignore[assignment]
- DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
- DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = {
- logging.critical,
- logging.debug,
- logging.error,
- logging.exception,
- logging.info,
- logging.log,
- logging.warning,
- print,
- warnings.warn,
- }
- @contextmanager
- def _ignore_backend_decomps():
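- # Temporarily disable the mkldnn and nnpack backend flags while tracing so that
- # backend-specific decompositions are not picked up during export; the previous
- # flags are restored on exit.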
- orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False)
- orig_nnpack_flag = torch.backends.nnpack.set_flags(False)
- try:
- yield
- finally:
- torch.backends.mkldnn.set_flags(*orig_mkldnn_flag)
- torch.backends.nnpack.set_flags(*orig_nnpack_flag)
- def _fixup_key(x):
- return "L__self__" + _strip_root(x)
- def _strip_root(x):
- if isinstance(x, str) and x.startswith("_export_root"):
- stripped = x[len("_export_root") :]
- return stripped[1:] if stripped.startswith(".") else stripped
- return x
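- # For example (illustrative): _strip_root("_export_root.linear.weight") returns
- # "linear.weight", and _fixup_key turns it into "L__self__linear.weight"; values
- # that don't start with "_export_root" pass through unchanged.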
- def _add_runtime_assertions_to_cond_in_subgraph(range_constraints, gm, fake_mode):
- # We can't get rid of this yet, since for some reason
- # insert_deferred_runtime_asserts doesn't add assertions to cond
- # subgraphs
- if len(range_constraints) > 0:
- stack_trace = (
- 'File "torch/_export/passes/add_runtime_assertions_for_constraints_pass.py", line 46, '
- "in _AddRuntimeAssertionsForInlineConstraintsPass"
- )
- with fake_mode, _set_node_metadata_hook(
- gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
- ):
- res = _AddRuntimeAssertionsForInlineConstraintsPass(range_constraints)(gm)
- assert res is not None
- gm = res.graph_module
- def _rewrite_node(gm):
- for node in gm.graph.nodes:
- if node.target == torch.ops.higher_order._export_tracepoint:
- if "path" in node.kwargs:
- path = _strip_root(node.kwargs["path"])
- with gm.graph.inserting_before(node):
- new_node = gm.graph.create_node(
- "call_function",
- torch.ops.higher_order._export_tracepoint,
- args=node.args,
- kwargs={
- "path": path,
- "kind": node.kwargs["kind"],
- },
- )
- new_node.meta = node.meta
- node.replace_all_uses_with(new_node)
- gm.graph.erase_node(node)
- def _convert_input_to_fake(gm, args, kwargs):
- params_buffers = _get_params_buffers(gm)
- fake_inps: List[torch.Tensor] = []
- for node in gm.graph.nodes:
- if node.op == "placeholder" and "val" in node.meta:
- fake_val = node.meta["val"]
- if fake_val is not None and isinstance(fake_val, torch.Tensor):
- fake_inps.append(fake_val)
- if detected_fake_mode := detect_fake_mode(fake_inps):
- fake_mode = detected_fake_mode
- else:
- fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True)
- if len(args) == 0 and len(kwargs) == 0:
- return (), {}, params_buffers, fake_mode
- count = 0
- def convert_to_fake(x):
- nonlocal count
- val = fake_inps[count]
- count += 1
- return val
- fake_args = pytree.tree_map_only(torch.Tensor, convert_to_fake, args)
- # TODO properly use the cached fake tensor
- fake_kwargs = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, kwargs)
- fake_params_buffers = pytree.tree_map_only(
- torch.Tensor,
- functools.partial(fake_mode.from_tensor, static_shapes=True),
- params_buffers,
- )
- return fake_args, fake_kwargs, fake_params_buffers, fake_mode
- def _replace_param_buffer_names(param_buffer_table, sig):
- for spec in sig.input_specs:
- if spec.kind in (
- InputKind.PARAMETER,
- InputKind.BUFFER,
- ):
- spec.target = param_buffer_table[spec.target]
- for spec in sig.output_specs:
- if spec.kind in (
- OutputKind.BUFFER_MUTATION,
- OutputKind.GRADIENT_TO_PARAMETER,
- ):
- spec.target = param_buffer_table[spec.target]
- def _convert_to_positional_args(orig_arg_names, args, kwargs):
- assert len(orig_arg_names) == len(args) + len(kwargs), (
- f"Total number of arg names is expected to be {len(orig_arg_names)} "
- f"but got {len(args)} positional args, {len(kwargs)} kwargs."
- )
- reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]]
- return (
- *args,
- *reordered_kwargs,
- )
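- # For example (illustrative): with orig_arg_names=["x", "y", "z"], args=(1,) and
- # kwargs={"y": 2, "z": 3}, this returns (1, 2, 3).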
- def _normalize_nn_module_stack(gm_torch_level, root_cls):
- # Append a root module to every nn_module_stack.
- root = "L['self']"
- root_key = re.sub(r"[^a-zA-Z0-9]", "_", root)
- for gm in gm_torch_level.modules():
- if not isinstance(gm, torch.fx.GraphModule):
- continue
- for node in gm.graph.nodes:
- if node.op in ["placeholder", "output"]:
- continue
- add_root = True
- if nn_module_stack := node.meta.get("nn_module_stack", {}):
- path, ty = next(iter(nn_module_stack.values()))
- # After deserializing the class `ty` might not exist anymore so
- # it could be a string
- if inspect.isclass(ty) and issubclass(ty, torch.nn.Module):
- # TODO Figure out why sometimes we have root and sometimes we don't.
- if path == root and ty is root_cls:
- add_root = False
- else:
- assert isinstance(ty, str)
- if add_root:
- def normalize_path(path):
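- # Turn a dynamo source string such as "L['self'].blocks[0].attn" (illustrative)
- # into a dotted path like "blocks.0.attn" by evaluating it against a recording proxy.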
- try:
- parts = []
- class Path:
- def __getattr__(self, name):
- parts.append(name)
- return self
- def __getitem__(self, idx):
- parts.append(str(idx))
- return self
- eval(path, {"L": {"self": Path()}})
- return ".".join(parts)
- except Exception: # TODO(zhxchen17) Remove this.
- return path
- nn_module_stack = {
- root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__),
- **nn_module_stack,
- }
- node.meta["nn_module_stack"] = {
- key: (normalize_path(path), ty)
- for key, (path, ty) in nn_module_stack.items()
- }
- def _get_param_buffer_mapping(
- original_module: torch.nn.Module,
- traced_module: torch.nn.Module,
- ) -> Dict[str, str]:
- """
- Returns a mapping of parameter/buffer names from the new module to the
- original model. This is to help with restoring the FQN for parameter/buffers
- of a traced module to what the original module contains.
- """
- param_lookup: Dict[int, List[str]] = {}
- buffer_lookup: Dict[int, List[str]] = {}
- for name, param in original_module.named_parameters(remove_duplicate=False):
- param_lookup.setdefault(id(param), []).append(name)
- for name, buffer in original_module.named_buffers(remove_duplicate=False):
- buffer_lookup.setdefault(id(buffer), []).append(name)
- # reverse lists so FQN assignment is FIFO wrt model structure
- for param_id, fqns in param_lookup.items():
- param_lookup[param_id] = fqns[::-1]
- for buffer_id, fqns in buffer_lookup.items():
- buffer_lookup[buffer_id] = fqns[::-1]
- param_buffer_table: Dict[str, str] = {}
- for dynamo_name, dynamo_param in traced_module.named_parameters(
- remove_duplicate=False
- ):
- assert dynamo_name not in param_buffer_table
- if id(dynamo_param) in param_lookup:
- param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)].pop()
- for dynamo_name, dynamo_buffer in traced_module.named_buffers(
- remove_duplicate=False
- ):
- assert dynamo_name not in param_buffer_table
- if id(dynamo_buffer) in buffer_lookup:
- param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop()
- return param_buffer_table
- def _remap_constants(
- orig_constant_attrs: ConstantAttrMap,
- graph_signature: ExportGraphSignature,
- constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],
- ) -> None:
- """Rewrite the graph signature and constants table to use the FQN from the original module."""
- remap_table: Dict[str, List[str]] = {}
- for name, value in constants.items():
- if value in orig_constant_attrs:
- remap_table[name] = orig_constant_attrs[value]
- for spec in graph_signature.input_specs:
- if spec.kind in (
- InputKind.CONSTANT_TENSOR,
- InputKind.CUSTOM_OBJ,
- ):
- orig_target = spec.target
- assert orig_target is not None
- targets = remap_table.get(orig_target, [orig_target])
- spec.target = targets[0]
- constant = constants[orig_target]
- del constants[orig_target]
- for target in targets:
- constants[target] = constant
- def _rename_constants_nodes(
- gm: torch.fx.GraphModule,
- graph_signature: ExportGraphSignature,
- ) -> None:
- """
- For strict mode, rename constants nodes that were previously annotated as buffers.
- """
- # handle name collisions with existing constants
- node_names = {node.name for node in gm.graph.nodes}
- def rename_constant(name):
- if name in node_names:
- n = 1
- while (dup_name := f"{name}_{n}") in node_names:
- n += 1
- name = dup_name
- node_names.add(name)
- return name
- # use input specs to map names from buffers to constants
- buffer_prefix = placeholder_prefixes[InputKind.BUFFER]
- const_prefix = placeholder_prefixes[InputKind.CONSTANT_TENSOR]
- buffer_to_constant = {}
- for spec in graph_signature.input_specs:
- if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith(
- const_prefix
- ):
- if spec.arg.name.startswith(buffer_prefix): # map from buffer to constants
- c_name = rename_constant(
- const_prefix + spec.arg.name[len(buffer_prefix) :]
- )
- else: # lifted constant
- c_name = rename_constant(const_prefix + spec.arg.name)
- buffer_to_constant[spec.arg.name] = c_name
- spec.arg.name = c_name
- for spec in graph_signature.output_specs:
- if spec.arg.name in buffer_to_constant:
- spec.arg.name = buffer_to_constant[spec.arg.name]
- # Rename constants nodes for all modules
- for mod in gm.modules():
- if not isinstance(mod, torch.fx.GraphModule):
- continue
- for node in mod.graph.nodes:
- if node.name in buffer_to_constant:
- node.name = node.target = buffer_to_constant[node.name]
- mod.recompile()
- def _restore_state_dict(
- original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
- ) -> None:
- """
- Restores the state dict of the traced module to that of the original module.
- """
- param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
- # Since the graph module is flattened (no module hierarchy), we
- # need to normalize the names by replacing "." with "_". If we
- # don't, it will try to save the weight to a submodule which no
- # longer exists.
- for name, fqn in param_buffer_table.items():
- param_buffer_table[name] = fqn.replace(".", "_")
- # Replace state dict attr names with the fqn
- for name, fqn in param_buffer_table.items():
- if not hasattr(traced_module, name):
- continue
- attr = getattr(traced_module, name)
- if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter):
- traced_module.register_buffer(fqn, attr)
- else:
- setattr(traced_module, fqn, attr)
- delattr(traced_module, name)
- # Replace graph getattr nodes with the correct name
- for node in traced_module.graph.nodes:
- if node.op == "get_attr":
- attr_name = node.target
- if attr_name in param_buffer_table:
- node.target = param_buffer_table[attr_name]
- traced_module.recompile()
- def _get_module_hierarchy(mod: torch.nn.Module) -> Dict[str, str]:
- return {
- name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False)
- }
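- # For example (illustrative): a module with a single child `self.linear` returns
- # {"": "MyModule", "linear": "Linear"}.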
- def _make_module_call_graph(
- module_hierarchy: Dict[str, str],
- in_spec: TreeSpec,
- out_spec: TreeSpec,
- module_call_signatures: Dict[str, ModuleCallSignature],
- ) -> List[ModuleCallEntry]:
- ret = [
- ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
- for fqn in module_hierarchy
- ]
- assert ret[0].fqn == ""
- ret[0].signature = ModuleCallSignature(
- inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
- )
- return ret
- def _export_to_torch_ir(
- f: Callable,
- args: Tuple[Any, ...],
- kwargs: Optional[Dict[str, Any]] = None,
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
- *,
- preserve_module_call_signature: Tuple[str, ...] = (),
- disable_constraint_solver: bool = False,
- _allow_complex_guards_as_runtime_asserts: bool = False,
- restore_fqn: bool = True,
- _log_export_usage: bool = True,
- same_signature: bool = True,
- ) -> torch.fx.GraphModule:
- """
- Traces either an nn.Module's forward function or just a callable with PyTorch
- operations inside and produces a torch.fx.GraphModule in torch IR.
- """
- if _log_export_usage:
- log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"})
- if not isinstance(args, tuple):
- raise UserError(
- UserErrorType.INVALID_INPUT,
- f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
- )
- kwargs = kwargs or {}
- with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
- try:
- module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
- with _wrap_submodules(
- f, preserve_module_call_signature, module_call_specs
- ), _ignore_backend_decomps():
- gm_torch_level, _ = torch._dynamo.export(
- f,
- dynamic_shapes=dynamic_shapes, # type: ignore[arg-type]
- assume_static_by_default=True,
- tracing_mode="symbolic",
- disable_constraint_solver=disable_constraint_solver,
- # currently the following 2 flags are tied together for export purposes,
- # but kept as separate flags for the sake of the dynamo export API
- prefer_deferred_runtime_asserts_over_guards=_allow_complex_guards_as_runtime_asserts,
- _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts,
- _log_export_usage=_log_export_usage,
- same_signature=same_signature,
- )(
- *args,
- **kwargs,
- )
- except (ConstraintViolationError, ValueRangeError) as e:
- raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
- except GuardOnDataDependentSymNode as e:
- raise UserError( # noqa: B904
- UserErrorType.ANTI_PATTERN,
- f"Consider annotating your code using torch._check*(). {str(e)}",
- case_name="constrain_as_size_example",
- )
- gm_torch_level.meta["module_call_specs"] = module_call_specs
- if isinstance(f, torch.nn.Module) and restore_fqn:
- _restore_state_dict(f, gm_torch_level)
- return gm_torch_level
- def _export_to_aten_ir(
- mod: torch.nn.Module,
- fake_args,
- fake_kwargs,
- fake_params_buffers,
- constant_attrs: ConstantAttrMap,
- *,
- transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later.
- pre_dispatch=False,
- _is_torch_jit_trace=False,
- ):
- # [NOTE] If the user is exporting under training mode, we want to detect any change to the
- # autograd global state and error out. If the user is exporting under inference mode, we
- # don't care. At the pre-dispatch level, we don't care about the state change either.
- is_grad_enabled = torch._C.is_grad_enabled()
- grad_safe_guard = nullcontext()
- if not pre_dispatch and is_grad_enabled:
- grad_safe_guard = AutogradStateOpsFailSafeguard() # type: ignore[assignment]
- @contextmanager
- def _compiling_state_context():
- old_value = torch.compiler._is_compiling_flag
- try:
- torch.compiler._is_compiling_flag = True
- yield
- finally:
- torch.compiler._is_compiling_flag = old_value
- # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
- # otherwise aot_export_module will error out because it sees a mix of fake_modes.
- # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
- with torch.nn.utils.stateless._reparametrize_module(
- mod,
- fake_params_buffers,
- tie_weights=True,
- strict=True,
- stack_weights=True,
- ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined]
- gm, graph_signature = transform(aot_export_module)(
- mod,
- fake_args,
- trace_joint=False,
- pre_dispatch=pre_dispatch,
- kwargs=fake_kwargs,
- )
- # TODO unfortunately preserving graph-level metadata is not
- # working well with aot_export. So we manually copy it.
- # (The node-level meta is addressed above.)
- if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
- gm.meta.update(mod.meta)
- def make_argument_spec(i, node) -> ArgumentSpec:
- if isinstance(node, (int, bool, float, type(None))):
- # For const outputs we just directly return this
- return ConstantArgument(name="", value=node)
- assert (
- "val" in node.meta
- ), f"{node} is not a constant or a node with a 'val' metadata field"
- val = node.meta["val"]
- if i < len(graph_signature.input_tokens):
- # TODO: We should be checking for a different type, once we add a new type
- return TokenArgument(name=node.name)
- elif isinstance(val, FakeTensor):
- return TensorArgument(name=node.name)
- elif isinstance(val, torch.SymInt):
- return SymIntArgument(name=node.name)
- elif isinstance(val, torch.ScriptObject):
- return CustomObjArgument(name=node.name, class_fqn=val._type().qualified_name()) # type: ignore[attr-defined]
- elif isinstance(val, FakeScriptObject):
- return CustomObjArgument(name=node.name, class_fqn=val.script_class_name)
- elif isinstance(val, (int, bool, str, float, type(None))):
- return ConstantArgument(name=node.name, value=val)
- else:
- raise AssertionError(
- f"Encountered an unsupported object of type {type(val)} "
- f"while writing the metadata for exported program"
- )
- is_joint = graph_signature.backward_signature is not None
- # NOTE: aot_export adds symint metadata for placeholders with int values;
- # since these become specialized, we replace such metadata with the original values
- flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
- index = 0
- total_non_user_inputs = (
- len(graph_signature.parameters)
- + len(graph_signature.buffers)
- + len(graph_signature.input_tokens)
- )
- for node in gm.graph.nodes:
- if node.op == "placeholder":
- if index >= total_non_user_inputs:
- user_arg = flat_args[index - total_non_user_inputs]
- if not isinstance(user_arg, torch.Tensor):
- node.meta["val"] = user_arg
- index += 1
- input_specs, output_specs = _sig_to_specs(
- user_inputs=set(graph_signature.user_inputs),
- inputs_to_parameters=graph_signature.inputs_to_parameters, # type: ignore[arg-type]
- inputs_to_buffers=graph_signature.inputs_to_buffers, # type: ignore[arg-type]
- user_outputs=set(graph_signature.user_outputs), # type: ignore[arg-type]
- buffer_mutations=graph_signature.buffers_to_mutate, # type: ignore[arg-type]
- user_input_mutations=graph_signature.user_inputs_to_mutate, # type: ignore[arg-type]
- grad_params=graph_signature.backward_signature.gradients_to_parameters if is_joint else {}, # type: ignore[arg-type, union-attr]
- grad_user_inputs=graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {}, # type: ignore[arg-type, union-attr]
- loss_output=graph_signature.backward_signature.loss_output if is_joint else None, # type: ignore[arg-type, union-attr]
- inputs=[
- make_argument_spec(i, node)
- for i, node in enumerate(gm.graph.nodes)
- if node.op == "placeholder"
- ],
- outputs=[
- make_argument_spec(i, node)
- for i, node in enumerate(
- pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
- )
- ],
- input_tokens=graph_signature.input_tokens,
- output_tokens=graph_signature.output_tokens,
- )
- export_graph_signature = ExportGraphSignature(
- input_specs=input_specs, output_specs=output_specs
- )
- from torch._guards import detect_fake_mode
- fake_mode = detect_fake_mode(flat_args)
- from torch._dynamo import config as _dynamo_config
- if not _dynamo_config.do_not_emit_runtime_asserts:
- stack_trace = (
- 'File "torch/fx/passes/runtime_assert.py", line 24, '
- "in insert_deferred_runtime_asserts"
- )
- with _set_node_metadata_hook(
- gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
- ):
- insert_deferred_runtime_asserts(
- gm,
- fake_mode.shape_env,
- f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
- export=True,
- )
- if pre_dispatch:
- from torch._export.passes.replace_set_grad_with_hop_pass import (
- replace_set_grad_with_hop_pass,
- )
- gm = replace_set_grad_with_hop_pass(gm, export_graph_signature)
- # Remove nn_module_stack and stack_trace metadata from all placeholder and output nodes.
- for _mod in gm.modules():
- if not isinstance(_mod, torch.fx.GraphModule):
- continue
- for node in _mod.graph.nodes:
- if node.op in ["placeholder", "output"]:
- node.meta.pop("nn_module_stack", None)
- node.meta.pop("stack_trace", None)
- constants = rewrite_script_object_meta(gm)
- constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))
- # Prettify names for placeholder nodes.
- placeholder_naming_pass(
- gm,
- export_graph_signature,
- mod,
- fake_args,
- fake_kwargs,
- fake_params_buffers,
- constants,
- )
- return ExportedArtifact(
- gm,
- export_graph_signature,
- constants,
- )
- def _get_params_buffers(mod: torch.nn.Module) -> Dict[str, torch.Tensor]:
- params_buffers: Dict[str, torch.Tensor] = {}
- for name, param in mod.named_parameters(remove_duplicate=False):
- params_buffers[name] = param
- for name, buffer in mod.named_buffers(remove_duplicate=False):
- params_buffers[name] = buffer
- return params_buffers
- def _get_forward_arg_names(
- mod: torch.nn.Module,
- args: Tuple[Any, ...],
- kwargs: Optional[Dict[str, Any]] = None,
- ) -> List[str]:
- """
- Gets the argument names to forward that are used, for restoring the
- original signature when unlifting the exported program module.
- - Positional args: retain the original argument names, and enumerate
- *args as args_0, args_1, ...
- - Keyword args: retain the original kwarg names in the order specified
- by the user. This order seems to matter for the current state of
- export lifted modules.
- """
- sig = inspect.signature(mod.forward)
- _args = sig.bind_partial(*args).arguments
- names: List[str] = []
- for name, value in _args.items():
- # handle variable number of positional args
- if sig.parameters[name].kind == inspect._ParameterKind.VAR_POSITIONAL:
- names.extend([f"{name}_{i}" for i, _ in enumerate(value)])
- else:
- names.append(name)
- # order of kwargs matters for input spec
- if kwargs:
- names.extend([kwarg for kwarg, _ in kwargs.items()])
- return names
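- # For example (illustrative): for `def forward(self, x, *args, mask=None)` called as
- # mod(t0, t1, t2, mask=m), this returns ["x", "args_0", "args_1", "mask"].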
- def _rewrite_dynamo_tensor_constants(
- orig_mod_buffers: Set[torch.Tensor],
- traced_mod_buffers: Dict[str, torch.Tensor],
- graph_signature: ExportGraphSignature,
- constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],
- ):
- """Dynamo erroneously marks tensor attributes on modules as a buffers.
- Rewrite them to be tensor constants.
- """
- for spec in graph_signature.input_specs:
- if spec.kind == InputKind.BUFFER:
- assert spec.target is not None
- value = traced_mod_buffers[spec.target]
- if value not in orig_mod_buffers:
- # This was a tensor constant erroneously marked as a buffer.
- # Convert it into a constant in the graph signature, and add its
- # value to the constants table.
- spec.kind = InputKind.CONSTANT_TENSOR
- constants[spec.target] = value
- def _rewrite_non_persistent_buffers(
- orig_mod: torch.nn.Module,
- graph_signature: ExportGraphSignature,
- constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],
- ):
- """Dynamo erroneously drops the persistent flag on buffers.
- Rewrite non-persistent buffers to reflect the original module.
- """
- state_dict = orig_mod.state_dict()
- for spec in graph_signature.input_specs:
- if spec.kind == InputKind.BUFFER:
- assert spec.target is not None
- if spec.target not in state_dict:
- assert spec.target not in constants
- spec.persistent = False
- constants[spec.target] = orig_mod.get_buffer(spec.target)
- def _verify_nn_module_stack(graph_module: torch.fx.GraphModule) -> None:
- """
- Perform nn_module_stack checks on the graph.
- Current constraints:
- For the top level graph:
- - populated for 'call_function', 'get_attr'
- - None for 'placeholder', 'output'
- For submodule graphs:
- - None for 'placeholder', 'output'
- TODO(pianpwk): make this a consistent node-level check once nn_module_stack is populated for cond submodules.
- """
- # Check top-level graph for all nodes, all graphs for placeholder & output nodes
- for i, mod in enumerate([graph_module] + list(graph_module.modules())):
- if not isinstance(mod, torch.fx.GraphModule):
- continue
- for node in mod.graph.nodes:
- if node.op in ["call_function", "get_attr"]:
- if i == 0:
- if (
- nn_module_stack := node.meta.get("nn_module_stack", None)
- ) is None:
- raise SpecViolationError(
- f"Node {node} of type {node.op} is missing nn_module_stack metadata"
- )
- if not all(
- isinstance(k, str)
- and isinstance(v, tuple)
- and len(v) == 2
- and all(isinstance(x, str) for x in v)
- for k, v in nn_module_stack.items()
- ):
- raise SpecViolationError(
- f"Node {node} of type {node.op} has incorrect nn_module_stack metadata format"
- f"expected Dict[str, Tuple[str, str]], but got {nn_module_stack}"
- )
- elif node.op in ["placeholder", "output"]:
- if node.meta.get("nn_module_stack", None):
- raise SpecViolationError(
- f"Node {node} of type {node.op} contains nn_module_stack metadata, this should be None"
- )
- def _verify_stack_trace(graph_module: torch.fx.GraphModule) -> None:
- """
- Perform stack trace checks on the graph.
- Constraints:
- - None or non-empty str for 'call_function', 'get_attr'
- - None for 'placeholder', 'output'
- """
- for i, mod in enumerate([graph_module] + list(graph_module.modules())):
- if not isinstance(mod, torch.fx.GraphModule):
- continue
- for node in mod.graph.nodes:
- stack_trace = node.meta.get("stack_trace", None)
- if node.op in ["call_function", "get_attr"]:
- if not (stack_trace is None or isinstance(stack_trace, str)):
- raise SpecViolationError(
- f"Node {node} of type {node.op} has invalid stack_trace metadata, "
- f"expected a string or None but instead found: {stack_trace}"
- )
- elif node.op in ["placeholder", "output"]:
- if stack_trace:
- raise SpecViolationError(
- f"Node {node} of type {node.op} contains stack_trace metadata, "
- f"expected None but instead found: {stack_trace}"
- )
- def _verify_placeholder_names(gm: torch.fx.GraphModule, sig: ExportGraphSignature):
- """
- Performs a sanity check on the placeholder node names.
- - User input nodes: no restrictions, should match the original forward() signature
- - Params/buffers/constants/custom_obj/token nodes: should start with prefixes defined in <placeholder_prefixes>
- """
- name_to_kind = {spec.arg.name: spec.kind for spec in sig.input_specs}
- for mod in gm.modules():
- if not isinstance(mod, torch.fx.GraphModule):
- continue
- for node in mod.graph.nodes:
- if node.op == "placeholder":
- if node.name not in name_to_kind:
- continue
- node_kind = name_to_kind[node.name]
- prefix = placeholder_prefixes[node_kind]
- if not node.name.startswith(prefix):
- raise SpecViolationError(
- f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}"
- )
- def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
- op_count = 0
- op_set = set()
- for m in ep.graph_module.modules():
- if not isinstance(m, torch.fx.GraphModule):
- continue
- for node in m.graph.nodes:
- if node.op != "call_function":
- continue
- op_count += 1
- assert hasattr(node.target, "__module__")
- assert hasattr(node.target, "__name__")
- op_set.add(f"{node.target.__module__}.{node.target.__name__}")
- return {"op_count": op_count, "op_set": op_set}
- _EXPORT_FLAGS: Optional[Set[str]] = None
- _EXPORT_MODULE_HIERARCHY: Optional[Dict[str, str]] = None
- def _log_export_wrapper(fn):
- @functools.wraps(fn)
- def wrapper(*args, **kwargs):
- global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY
- try:
- start = time.time()
- ep = fn(*args, **kwargs)
- end = time.time()
- log_export_usage(
- event="export.time",
- metrics=end - start,
- flags=_EXPORT_FLAGS,
- **get_ep_stats(ep),
- )
- except Exception as e:
- t = type(e)
- error_type = t.__module__ + "." + t.__qualname__
- log_export_usage(
- event="export.error",
- type=error_type,
- message=str(e),
- flags=_EXPORT_FLAGS,
- )
- raise e
- finally:
- _EXPORT_FLAGS = None
- _EXPORT_MODULE_HIERARCHY = None
- return ep
- return wrapper
- def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs):
- if not isinstance(example_inputs, (tuple, list, dict)):
- example_inputs = (example_inputs,)
- elif isinstance(example_inputs, list):
- example_inputs = tuple(example_inputs)
- elif (
- isinstance(example_inputs, (torch.Tensor, dict))
- and example_kwarg_inputs is None
- ):
- example_inputs = (example_inputs,)
- if example_kwarg_inputs is None:
- example_kwarg_inputs = {}
- return example_inputs, example_kwarg_inputs
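- # For example (illustrative): _process_jit_trace_inputs_for_export(torch.randn(2), None)
- # returns ((tensor,), {}), so downstream code always sees an (args, kwargs) pair.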
- @contextmanager
- def patch_forward(obj: torch.nn.Module, new_method):
- """Helper method to make it easier to cleanly torch.export() a method on a
- module that is not `forward`.
- """
- # Save the original method
- original_method = obj.forward
- # Patch the method
- obj.forward = new_method.__get__(obj, obj.__class__)
- try:
- yield
- finally:
- # Restore the original method
- obj.forward = original_method
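- # Hedged usage sketch (module and method names are illustrative, not part of this file):
- #
- #     class MyModule(torch.nn.Module):
- #         def encode(self, x):
- #             return x * 2
- #
- #     m = MyModule()
- #     with patch_forward(m, MyModule.encode):
- #         ep = torch.export.export(m, (torch.randn(3),))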
- @contextmanager
- def _temp_disable_texpr_fuser():
- original_state = torch._C._jit_texpr_fuser_enabled()
- torch._C._jit_set_texpr_fuser_enabled(False)
- try:
- yield
- finally:
- torch._C._jit_set_texpr_fuser_enabled(original_state)
- class _WrapperModule(torch.nn.Module):
- def __init__(self, f):
- super().__init__()
- self.f = f
- def forward(self, *args, **kwargs):
- return self.f(*args, **kwargs)
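- # For example (illustrative): _WrapperModule(lambda x: x + 1) wraps a bare callable
- # in an nn.Module so the export pipeline, which expects a module, can trace it.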
- def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
- with _temp_disable_texpr_fuser():
- from torch.jit._trace import TopLevelTracedModule
- export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs)
- if isinstance(traced_callable, (TopLevelTracedModule, torch._C.ScriptModule)): # type: ignore[operator]
- return _export(
- traced_callable,
- export_args,
- export_kwargs,
- strict=False,
- _is_torch_jit_trace=True,
- ).module()
- elif isinstance(traced_callable, torch.ScriptMethod) and isinstance(
- traced_callable.owner(), (torch._C.ScriptModule, torch.nn.Module) # type: ignore[operator]
- ):
- with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator]
- return _export(
- traced_callable.owner(), # type: ignore[operator]
- export_args,
- export_kwargs,
- strict=False,
- _is_torch_jit_trace=True,
- ).module()
- else:
- return _export(
- _WrapperModule(traced_callable),
- export_args,
- export_kwargs,
- strict=False,
- _is_torch_jit_trace=True,
- ).module()
- def _strict_export(
- mod: torch.nn.Module,
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
- preserve_module_call_signature: Tuple[str, ...],
- pre_dispatch: bool,
- original_state_dict: Dict[str, Any],
- orig_in_spec: TreeSpec,
- _allow_complex_guards_as_runtime_asserts: bool,
- _disable_forced_specializations: Optional[bool],
- _is_torch_jit_trace: bool,
- ):
- gm_torch_level = _export_to_torch_ir(
- mod,
- args,
- kwargs,
- dynamic_shapes,
- preserve_module_call_signature=preserve_module_call_signature,
- restore_fqn=False, # don't need to restore because we will do it later
- _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts,
- _log_export_usage=False,
- )
- # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
- (
- fake_args,
- fake_kwargs,
- fake_params_buffers,
- dynamo_fake_mode,
- ) = _convert_input_to_fake(gm_torch_level, args, kwargs)
- # First, we want to pass through the graph to try populating
- # val field for getattr if there is anything missing.
- # This can happen when quantization adds extra params and forgets
- # to update "val"
- for node in gm_torch_level.graph.nodes:
- if node.op == "get_attr" and "val" not in node.meta:
- attr = getattr(gm_torch_level, node.target)
- # Checks if it is not a HigherOrderOp branch or a module
- if not isinstance(attr, torch.nn.Module):
- assert (
- dynamo_fake_mode is not None
- ), "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
- node.meta["val"] = dynamo_fake_mode.from_tensor(
- attr, static_shapes=True
- )
- # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace)
- # from the param nodes as they are treated as fresh inputs
- # Therefore, we manually extract them before calling into aot_export
- params_buffers_to_node_meta = {}
- for node in gm_torch_level.graph.nodes:
- target = node.target
- meta = node.meta
- if node.op == "call_module":
- submodule = getattr(gm_torch_level, target)
- if isinstance(submodule, torch.nn.Module):
- for name, _ in submodule.named_parameters(
- recurse=True, remove_duplicate=False
- ):
- params_buffers_to_node_meta[target + "." + name] = meta
- for name, _ in submodule.named_buffers(
- recurse=True, remove_duplicate=False
- ):
- params_buffers_to_node_meta[target + "." + name] = meta
- if node.op == "get_attr":
- submodule = getattr(gm_torch_level, target)
- if not isinstance(submodule, torch.fx.GraphModule):
- params_buffers_to_node_meta[target] = meta
- # If the call_function uses param as input, we also need to update params' meta
- # with this call_function node's meta.
- # This is basically the same flow as torch.fx.traceback.preserve_node_meta()
- if node.op == "call_function" and not isinstance(
- node.target, torch._ops.HigherOrderOperator
- ):
- for arg in node._input_nodes:
- if arg.op == "get_attr":
- for entry in torch.fx.proxy._COPY_META_FIELDS:
- if entry in meta:
- params_buffers_to_node_meta[arg.target][entry] = meta[entry]
- # Fix the graph output signature to be tuple if scalar
- out_spec = orig_out_spec = gm_torch_level._out_spec
- # Used to get rid of lint type error.
- assert out_spec is not None
- # aot_export expects the return type to always be a tuple.
- if out_spec.type not in (list, tuple):
- out_spec = pytree.TreeSpec(tuple, None, [out_spec])
- orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined]
- gm_torch_level.graph._codegen = _PyTreeCodeGen(
- _PyTreeInfo(
- orig_arg_names,
- gm_torch_level._in_spec,
- out_spec,
- )
- )
- gm_torch_level.recompile()
- _normalize_nn_module_stack(gm_torch_level, type(mod))
- # NOTE: graph module expects only positional args
- constant_attrs = _gather_constant_attrs(mod)
- with dynamo_fake_mode:
- aten_export_artifact = _export_to_aten_ir(
- gm_torch_level,
- _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs),
- {},
- fake_params_buffers,
- constant_attrs,
- pre_dispatch=pre_dispatch,
- )
- # Decompose for readability.
- gm = aten_export_artifact.gm
- export_graph_signature = aten_export_artifact.sig
- constants = aten_export_artifact.constants
- # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes
- for metadata in params_buffers_to_node_meta.values():
- metadata.pop("nn_module_stack", None)
- metadata.pop("stack_trace", None)
- # After aot_export, set the param/buffer metadata back into placeholders
- # Technically, users can still construct this data from param names
- # without relying on this metadata
- for node in gm.graph.nodes:
- if node.op == "placeholder":
- if node.target in export_graph_signature.inputs_to_parameters:
- param_name = export_graph_signature.inputs_to_parameters[node.target]
- if param_name in params_buffers_to_node_meta:
- for k, v in params_buffers_to_node_meta[param_name].items():
- node.meta[k] = v
- if node.target in export_graph_signature.inputs_to_buffers:
- buffer_name = export_graph_signature.inputs_to_buffers[node.target]
- if buffer_name in params_buffers_to_node_meta:
- for k, v in params_buffers_to_node_meta[buffer_name].items():
- node.meta[k] = v
- # Do some cleanups on the graph module to restore the state dict to the
- # expected form. Each of these steps should probably get fixed upstream.
- # 1. Remove tensor constants that were added as buffers.
- _rewrite_dynamo_tensor_constants(
- orig_mod_buffers=set(mod.buffers()),
- traced_mod_buffers=dict(gm_torch_level.named_buffers()),
- graph_signature=export_graph_signature,
- constants=constants,
- )
- # 2. Restore FQN of param/buffers
- param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)
- _replace_param_buffer_names(param_buffer_table, export_graph_signature)
- # 3. Remove non-persistent buffers from the graph signature
- _rewrite_non_persistent_buffers(mod, export_graph_signature, constants)
- # 4. Rewrite constants to have the same FQN as the original module.
- _remap_constants(constant_attrs, export_graph_signature, constants)
- # 5. Rename constants nodes in graph module from buffers to constants
- _rename_constants_nodes(gm, export_graph_signature)
- aten_export_artifact.out_spec = orig_out_spec
- aten_export_artifact.fake_mode = dynamo_fake_mode
- aten_export_artifact.module_call_specs = gm_torch_level.meta["module_call_specs"]
- return aten_export_artifact
- def _non_strict_export(
- mod: torch.nn.Module,
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
- preserve_module_call_signature: Tuple[str, ...],
- pre_dispatch: bool,
- original_state_dict: Dict[str, Any],
- orig_in_spec: TreeSpec,
- _allow_complex_guards_as_runtime_asserts: bool,
- _disable_forced_specializations: Optional[bool],
- _is_torch_jit_trace: bool,
- ):
- out_spec = None
- module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
- def _tuplify_outputs(aot_export):
- def _aot_export_non_strict(mod, args, kwargs=None, **flags):
- kwargs = kwargs or {}
- class Wrapper(torch.nn.Module):
- def __init__(self, mod):
- super().__init__()
- self._export_root = mod
- def forward(self, *args, **kwargs):
- nonlocal out_spec
- if isinstance(self._export_root, torch.fx.GraphModule):
- with torch.fx.traceback.preserve_node_meta():
- tree_out = torch.fx.Interpreter(self._export_root).run(
- *args, **kwargs
- )
- else:
- tree_out = self._export_root(*args, **kwargs)
- flat_outs, out_spec = pytree.tree_flatten(tree_out)
- return tuple(flat_outs)
- wrapped_mod = Wrapper(mod)
- # Prefix the preserved call signatures with the `_export_root` wrapper so that the
- # wrapper module correctly populates the in/out specs
- new_preserved_call_signatures = [
- "_export_root." + i for i in preserve_module_call_signature
- ]
- with _wrap_submodules(
- wrapped_mod, new_preserved_call_signatures, module_call_specs
- ):
- gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
- log.debug("Exported program from AOTAutograd:\n%s", gm)
- sig.parameters = pytree.tree_map(_strip_root, sig.parameters)
- sig.buffers = pytree.tree_map(_strip_root, sig.buffers)
- sig.inputs_to_buffers = pytree.tree_map(_strip_root, sig.inputs_to_buffers)
- sig.inputs_to_parameters = pytree.tree_map(
- _strip_root, sig.inputs_to_parameters
- )
- sig.buffers_to_mutate = pytree.tree_map(_strip_root, sig.buffers_to_mutate)
- for node in gm.graph.nodes:
- if "nn_module_stack" in node.meta:
- nn_module_stack = node.meta["nn_module_stack"]
- node.meta["nn_module_stack"] = {
- _fixup_key(key): val
- for key, val in pytree.tree_map(
- _strip_root, nn_module_stack
- ).items()
- }
- return gm, sig
- return _aot_export_non_strict
- (
- fake_mode,
- fake_args,
- fake_kwargs,
- equalities_inputs,
- original_signature,
- ) = make_fake_inputs(
- mod,
- args,
- kwargs,
- dynamic_shapes,
- _is_torch_jit_trace=_is_torch_jit_trace,
- _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, # for shape env initialization
- )
- fake_params_buffers = make_fake_params_buffers(fake_mode, _get_params_buffers(mod))
- with fake_mode:
- with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
- patched_mod,
- new_fake_args,
- new_fake_kwargs,
- new_fake_constant_attrs,
- map_fake_to_real,
- ):
- aten_export_artifact = _export_to_aten_ir(
- patched_mod,
- new_fake_args,
- new_fake_kwargs,
- fake_params_buffers,
- new_fake_constant_attrs,
- pre_dispatch=pre_dispatch,
- transform=_tuplify_outputs,
- _is_torch_jit_trace=_is_torch_jit_trace,
- )
- # aten_export_artifact.constants may contain fake script objects; map them back to the real ones
- aten_export_artifact.constants = {
- fqn: map_fake_to_real[obj] if isinstance(obj, FakeScriptObject) else obj
- for fqn, obj in aten_export_artifact.constants.items()
- }
- try:
- produce_guards_and_solve_constraints(
- fake_mode,
- aten_export_artifact.gm,
- dynamic_shapes,
- equalities_inputs,
- original_signature,
- _disable_forced_specializations=_disable_forced_specializations,
- _is_torch_jit_trace=_is_torch_jit_trace,
- )
- except (ConstraintViolationError, ValueRangeError) as e:
- raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
- _rewrite_non_persistent_buffers(
- mod, aten_export_artifact.sig, aten_export_artifact.constants
- )
- aten_export_artifact.out_spec = out_spec
- aten_export_artifact.fake_mode = fake_mode
- aten_export_artifact.module_call_specs = module_call_specs
- return aten_export_artifact
- @_log_export_wrapper
- @_disable_prexisiting_fake_mode
- def _export(
- mod: torch.nn.Module,
- args: Tuple[Any, ...],
- kwargs: Optional[Dict[str, Any]] = None,
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
- *,
- strict: bool = True,
- preserve_module_call_signature: Tuple[str, ...] = (),
- pre_dispatch: bool = False,
- _allow_complex_guards_as_runtime_asserts: bool = False,
- _disable_forced_specializations: Optional[bool] = False,
- _is_torch_jit_trace: bool = False,
- ) -> ExportedProgram:
- """
- Traces either an nn.Module's forward function or just a callable with PyTorch
- operations inside and produces an ExportedProgram.
- Args:
- mod: the `nn.Module` to trace.
- args: example positional inputs.
- kwargs: optional example keyword inputs.
- dynamic_shapes:
- An optional argument where the type should either be:
- 1) a dict from argument names of ``mod``'s ``forward()`` to their dynamic shape specifications,
- 2) a tuple that specifies dynamic shape specifications for each input in original order.
- If you are specifying dynamism on keyword args, you will need to pass them in the order that
- is defined in the original function signature.
- The dynamic shape of a tensor argument can be specified as either
- (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
- not required to include static dimension indices in this dict, but when they are,
- they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
- where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
- are denoted by None. Arguments that are dicts or tuples / lists of tensors are
- recursively specified by using mappings or sequences of contained specifications.
- preserve_module_call_signature: A list of submodule paths for which the original
- calling conventions are preserved as metadata.
- _allow_complex_guards_as_runtime_asserts:
- With the current dynamic shapes language for dims and derived dims, we can run into constraints
- that are not expressible with the language. For example, flattening a matrix and adding to a vector,
- both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible.
- By default, we either raise a constraint violation error or specialize to static values.
- If this flag is set to True, we avoid erroring out and instead allow complex constraints to exist as runtime
- assertions in the graph. The sympy interpreter (torch/utils/_sympy/interp.py) will produce the math ops
- required to compute and assert the value of the guard (e.g. sym_size_int, eq, _assert_scalar).
- Additionally, if TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 is specified, we will allow complex constraints
- while not emitting runtime asserts, returning a cleaner graph with fewer guarantees around dynamic shapes.
- _disable_forced_specializations:
- Similar to _allow_complex_guards_as_runtime_asserts, but only avoids specializing to static values if set to True.
- For complex guards that don't specialize, this flag doesn't have any effect. Ideally this would be subsumed by
- _allow_complex_guards_as_runtime_asserts, but this handles one additional case: single-variable equalities where
- the symbol is solvable for a concrete value (e.g. Eq(s0 // 4, 400) -> s0 = 1600). If set to True, this flag will
- avoid specializations. Direct equalities (e.g. s0 = 4), will still specialize.
- Returns:
- An ExportedProgram containing the traced method.
- """
- if not isinstance(args, tuple):
- raise UserError(
- UserErrorType.INVALID_INPUT,
- f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
- )
- if _disable_forced_specializations and strict:
- raise UserError(
- UserErrorType.INVALID_INPUT,
- "_disable_forced_specializations can be only be specified in non-strict mode.",
- )
- global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY
- _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
- flags = set()
- flags.add("strict" if strict else "non_strict")
- flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch")
- log_export_usage(event="export.enter", flags=flags)
- _EXPORT_FLAGS = flags
- kwargs = kwargs or {}
- if isinstance(dynamic_shapes, torch.export.ShapesCollection):
- dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs)
- flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
- original_state_dict = mod.state_dict(keep_vars=True)
- if not _is_torch_jit_trace:
- forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
- else:
- forward_arg_names = None
- # Call the appropriate export function based on the strictness of tracing.
- export_func = _strict_export if strict else _non_strict_export
- aten_export_artifact = export_func(
- mod,
- args,
- kwargs,
- dynamic_shapes,
- preserve_module_call_signature,
- pre_dispatch,
- original_state_dict,
- orig_in_spec,
- _allow_complex_guards_as_runtime_asserts,
- _disable_forced_specializations,
- _is_torch_jit_trace,
- )
- # Decompose here for readability.
- gm = aten_export_artifact.gm
- export_graph_signature = aten_export_artifact.sig
- out_spec = aten_export_artifact.out_spec
- constants = aten_export_artifact.constants
- fake_mode = aten_export_artifact.fake_mode
- module_call_specs = aten_export_artifact.module_call_specs
- # Add forward args metadata.
- gm.meta["forward_arg_names"] = forward_arg_names
- # The unbacked symint symbols are updated in aot_export
- # so we serialize them here instead of inside dynamo.
- gm.meta["inline_constraints"] = {
- k: v
- for k, v in fake_mode.shape_env.var_to_range.items()
- if free_unbacked_symbols(k)
- }
- num_lifted = next(
- (
- i
- for i, s in enumerate(export_graph_signature.input_specs)
- if s.kind == InputKind.USER_INPUT
- ),
- len(export_graph_signature.input_specs),
- )
- combined_args = _combine_args(
- mod, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace
- )
- range_constraints = make_constraints(
- fake_mode,
- gm,
- combined_args,
- dynamic_shapes,
- num_lifted,
- )
- if strict:
- _add_runtime_assertions_to_cond_in_subgraph(
- range_constraints,
- gm,
- fake_mode,
- )
- # Make module signatures.
- module_call_signatures = {}
- for fqn, specs in module_call_specs.items():
- mod_fqn = _strip_root(fqn) if not strict else fqn
- module_call_signatures[mod_fqn] = ModuleCallSignature(
- inputs=[], outputs=[], **specs
- )
- if len(preserve_module_call_signature) > 0:
- if not strict:
- _rewrite_node(gm)
- res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm)
- assert res is not None
- gm = res.graph_module
- assert out_spec is not None
- _verify_nn_module_stack(gm)
- _verify_stack_trace(gm)
- if not _is_torch_jit_trace:
- _verify_placeholder_names(gm, export_graph_signature)
- exported_program = ExportedProgram(
- root=gm,
- graph=gm.graph,
- graph_signature=export_graph_signature,
- state_dict=original_state_dict,
- range_constraints=range_constraints,
- module_call_graph=_make_module_call_graph(
- _EXPORT_MODULE_HIERARCHY,
- orig_in_spec,
- out_spec,
- module_call_signatures,
- ),
- example_inputs=(args, kwargs),
- constants=aten_export_artifact.constants,
- )
- return exported_program