| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738 |
- # mypy: allow-untyped-defs
- from collections import defaultdict
- from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
- import torch.utils._pytree as pytree
- from . import _pytree as fx_pytree
- from ._compatibility import compatibility
- import os
- import contextlib
- from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type, Iterable
- from dataclasses import dataclass
- from contextlib import contextmanager
- import copy
- import enum
- import torch
- import keyword
- import re
- import builtins
- import math
- import warnings
- import inspect
# Public names exported by this module.
__all__ = ["PythonCode", "CodeGen", "Graph"]

# These imports are only needed for static type checking; guarding them with
# TYPE_CHECKING keeps them out of the runtime import graph.
if TYPE_CHECKING:
    from .graph_module import GraphModule  # noqa: F401
    from ._symbolic_trace import Tracer  # noqa: F401
- # Mapping of builtins to their `typing` equivalent.
- _origin_type_map = {
- list: List,
- dict: Dict,
- set: Set,
- frozenset: FrozenSet,
- tuple: Tuple,
- }
- # Signature for functions thattransforms the body (`list[str]`) of the
- # generated code
- TransformCodeFunc = Callable[[List[str]], List[str]]
class _CustomBuiltin(NamedTuple):
    """Additional objs that we add to every graph's globals.

    The repr() for some standard library objects is not valid Python code without
    an import. For common objects of this sort, we bundle them in the globals of
    every FX graph.
    """
    # How to import this object from the standard library: a standalone Python
    # statement that binds the object's name when executed.
    import_str: str
    # The actual object, produced from that import string.
    obj: Any
# Registry of objects injected into the globals of every generated FX graph,
# keyed by the identifier generated code uses to refer to them.
_custom_builtins: Dict[str, _CustomBuiltin] = {}


def _register_custom_builtin(name: str, import_str: str, obj: Any):
    """Record ``obj`` in the custom-builtin registry under ``name``.

    ``import_str`` is a standalone Python statement that binds ``name`` to
    ``obj`` when executed; it is stored alongside the object.
    """
    entry = _CustomBuiltin(import_str, obj)
    _custom_builtins[name] = entry


# Registration order is preserved (dicts keep insertion order).
for _builtin_name, _builtin_import, _builtin_obj in (
    ('inf', 'from math import inf', math.inf),
    ('nan', 'from math import nan', math.nan),
    ('NoneType', 'NoneType = type(None)', type(None)),
    ('torch', 'import torch', torch),
    ('device', 'from torch import device', torch.device),
    ('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree),
    ('pytree', 'import torch.utils._pytree as pytree', pytree),
):
    _register_custom_builtin(_builtin_name, _builtin_import, _builtin_obj)
# Don't leak the loop variables into the module namespace.
del _builtin_name, _builtin_import, _builtin_obj
- def _is_magic(x: str) -> bool:
- return x.startswith('__') and x.endswith('__')
- def _snake_case(s: str) -> str:
- """
- Transforms the given string ``s`` to a Python-style variable name
- Examples:
- ``mod.snake_case`` -> ``mod.snake_case``
- ``mod.pascalCase``-> ``mod.pascal_case``
- ``mod.ALL_CAPS`` -> ``mod.all_caps``
- """
- chars = []
- prev_lower = False
- for c in s:
- if prev_lower and c.isupper():
- chars.append('_')
- chars.append(c.lower())
- prev_lower = c.islower()
- return ''.join(chars)
- def _is_from_torch(obj: Any) -> bool:
- module_name = getattr(obj, '__module__', None)
- if module_name is not None:
- base_module = module_name.partition('.')[0]
- return (
- base_module == 'torch' and
- not module_name.startswith("torch._dynamo.") and
- not module_name.startswith("torch._inductor.")
- )
- name = getattr(obj, '__name__', None)
- # exclude torch because torch.torch.torch.torch works. idk mang
- if name is not None and name != 'torch':
- for guess in [torch, torch.nn.functional]:
- if getattr(guess, name, None) is obj:
- return True
- return False
class _Namespace:
    """Hands out unique, valid Python identifiers and remembers which object owns each one.

    Invariants maintained:
    - Each object gets a single name.
    - Each name is unique within a given namespace.
    - Names generated do not shadow builtins (or the registered custom
      builtins), unless the object is indeed that builtin.
    """

    def __init__(self):
        # object -> the name it was given
        self._obj_to_name: Dict[Any, str] = {}
        # names created with obj=None; may be bound later via associate_name_with_obj
        self._unassociated_names = set()
        # every name handed out so far
        self._used_names: Set[str] = set()
        # base name -> last numeric suffix recorded for that base
        self._base_count: Dict[str, int] = defaultdict(int)
        self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+')
        self._name_suffix_regex = re.compile(r"(.*)_(\d+)$")

    def create_name(self, candidate: str, obj: Optional[Any]) -> str:
        """Create a unique name.

        Arguments:
            candidate: used as the basis for the unique name, relevant to the user.
            obj: If not None, an object that will be associated with the unique
                name. An object that already has a name gets that name back.
        """
        if obj is not None:
            known = self._obj_to_name.get(obj)
            if known is not None:
                return known

        # Replace every run of characters that are illegal in an identifier
        # with a single underscore; fall back to a placeholder for empty input.
        sanitized = self._illegal_char_regex.sub('_', candidate) or '_unnamed'
        if sanitized[0].isdigit():
            sanitized = '_' + sanitized

        # Split off a trailing "_<digits>" suffix so numbering continues from it.
        suffix = self._name_suffix_regex.match(sanitized)
        if suffix is not None:
            base, digits = suffix.group(1, 2)
            num = int(digits)
            name = f'{base}_{num}'
        else:
            base, num, name = sanitized, None, sanitized

        if not num:
            num = self._base_count[base]
        while name in self._used_names or self._is_illegal_name(name, obj):
            num += 1
            name = f'{base}_{num}'

        self._used_names.add(name)
        self._base_count[base] = num
        if obj is not None:
            self._obj_to_name[obj] = name
        else:
            self._unassociated_names.add(name)
        return name

    def associate_name_with_obj(self, name: str, obj: Any):
        """Bind a previously unassociated ``name`` to ``obj``.

        Neither `name` nor `obj` should be associated already.
        """
        assert obj not in self._obj_to_name
        assert name in self._unassociated_names
        self._unassociated_names.remove(name)
        self._obj_to_name[obj] = name

    def _is_illegal_name(self, name: str, obj: Any) -> bool:
        """Return True if ``name`` may not be used for ``obj``."""
        # Keywords are never valid identifiers.
        if name in keyword.kwlist:
            return True
        # A builtin name may only be reused by that very builtin.
        if name in builtins.__dict__:
            return builtins.__dict__[name] is not obj
        # Same rule for the custom builtins injected into every graph.
        if name in _custom_builtins:
            return _custom_builtins[name].obj is not obj
        return False

    def _rename_object(self, obj: Any, name: str):
        """Point an already-named ``obj`` at ``name`` and mark ``name`` as used."""
        assert obj in self._obj_to_name
        self._obj_to_name[obj] = name
        self._used_names.add(name)
- dtype_abbrs = {
- torch.bfloat16: 'bf16',
- torch.float64: 'f64',
- torch.float32: 'f32',
- torch.float16: 'f16',
- torch.float8_e4m3fn: 'f8e4m3fn',
- torch.float8_e5m2: 'f8e5m2',
- torch.float8_e4m3fnuz: 'f8e4m3fnuz',
- torch.float8_e5m2fnuz: 'f8e5m2fnuz',
- torch.complex32: 'c32',
- torch.complex64: 'c64',
- torch.complex128: 'c128',
- torch.int8: 'i8',
- torch.int16: 'i16',
- torch.int32: 'i32',
- torch.int64: 'i64',
- torch.bool: 'b8',
- torch.uint8: 'u8',
- torch.uint32: 'u32',
- torch.uint64: 'u64',
- }
@compatibility(is_backward_compatible=True)
@dataclass
class PythonCode:
    """
    Represents all the information necessary to exec or save a graph as Python code.
    """
    # Python source code for the forward function definition.
    src: str
    # Values in global scope during execution of `src`.
    globals: Dict[str, Any]
    # Optional mapping from the forward function's line number to
    # node index (None for lines not attributed to any node).
    _lineno_map: Optional[Dict[int, Optional[int]]]
- def _format_target(base: str, target: str) -> str:
- elems = target.split('.')
- r = base
- for e in elems:
- if not e.isidentifier():
- r = f'getattr({r}, "{e}")'
- else:
- r = f'{r}.{e}'
- return r
- class _InsertPoint:
- def __init__(self, graph, new_insert):
- self.graph = graph
- self.orig_insert, graph._insert = graph._insert, new_insert
- def __enter__(self):
- pass
- def __exit__(self, type, value, tb):
- self.graph._insert = self.orig_insert
- class _node_list:
- def __init__(self, graph: 'Graph', direction: str = '_next'):
- assert direction in ['_next', '_prev']
- self.graph = graph
- self.direction = direction
- def __len__(self):
- return self.graph._len
- def __iter__(self):
- root = self.graph._root
- if self.direction == "_next":
- cur = root._next
- while cur is not root:
- if not cur._erased:
- yield cur
- cur = cur._next
- else:
- assert self.direction == "_prev"
- cur = root._prev
- while cur is not root:
- if not cur._erased:
- yield cur
- cur = cur._prev
- def __reversed__(self):
- return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
class _PyTreeInfo(NamedTuple):
    """
    Contains extra info stored when we're using Pytrees
    """
    # Argument names of the original (un-flattened) function, in order.
    orig_args: List[str]
    # Spec describing how the original inputs were flattened.
    in_spec: pytree.TreeSpec
    # Spec used to rebuild the output structure; None means outputs are
    # returned as-is.
    out_spec: Optional[pytree.TreeSpec]
- @dataclass(frozen=True)
- class _ParsedStackTrace:
- """
- Represents the top-most frame of a parsed stack trace
- """
- file: str
- lineno: str
- name: str
- code: str
- def get_summary_str(self):
- return f'File: {self.file}:{self.lineno} in {self.name}, code: {self.code}'
- # get File:lineno code from stack_trace
- def _parse_stack_trace(stack_trace: str):
- if stack_trace is None:
- return None
- pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
- lines = stack_trace.strip().split('\n')
- # stacktrace should have innermost frame last, so we
- # iterate backwards to find the first line that starts
- # with 'File '
- summary_str = ""
- for idx in range(len(lines) - 2, -1, -1):
- line = lines[idx].strip()
- matches = pattern.match(line)
- if matches:
- file = matches.group(1)
- lineno = matches.group(2)
- name = matches.group(3)
- # next line should be the code
- code = lines[idx + 1].strip()
- return _ParsedStackTrace(file, lineno, name, code)
- return None
@compatibility(is_backward_compatible=False)
class CodeGen:
    """Turns a sequence of FX nodes into the Python source of a function.

    Subclasses customize the output by overriding the hook methods
    ``gen_fn_def``, ``generate_output``, ``process_inputs``,
    ``process_outputs`` and ``additional_globals``.
    """

    def __init__(self):
        # Optional hook applied to the list of emitted body fragments before
        # they are joined into source (see ``_gen_python_code``).
        self._body_transformer: Optional[TransformCodeFunc] = None
        # Name given to the generated function.
        self._func_name: str = "forward"

    def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        """
        Given the free variables and a return annotation, generates the beginning of the FX function.
        By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(self, a, b):'`

        Note: if ``free_vars`` does not already start with ``'self'``, ``'self'``
        is inserted at the front of the list in place.
        """
        # If the original function didn't have self as its first argument, we
        # would have added it.
        if len(free_vars) == 0 or free_vars[0] != 'self':
            free_vars.insert(0, 'self')
        return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:"

    def generate_output(self, output_args: Argument) -> str:
        """
        Given the output arguments, generates the return statement of the FX function.
        Note: The returned statement should not be indented.
        """
        return f'return {repr(output_args)}'

    def process_inputs(self, *args: Any) -> Any:
        """
        Transforms the inputs so that the graph can take them as arguments, as
        non-default codegen may result in the inputs to the function being
        different from the inputs to the graph.

        If the graph was directly runnable, this invariant should hold true
        `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)`

        The default implementation returns ``args`` unchanged.
        """
        return args

    def process_outputs(self, outputs: Any) -> Any:
        """
        Transforms the outputs of the graph to be identical to the codegen.

        See ``process_inputs`` for more details. The default implementation
        returns ``outputs`` unchanged.
        """
        return outputs

    def additional_globals(self) -> List[Tuple[str, Any]]:
        """
        If your codegen uses extra global values, add tuples of (identifier, reference to the value) here.
        For example, return [('List', typing.List)] if you need ``List`` in the global context.
        """
        return []

    def _gen_python_code(
        self, nodes, root_module: str, namespace: _Namespace, *,
        verbose: bool = False, include_stride: bool = False, include_device: bool = False
    ) -> PythonCode:
        """Generate the source and globals of a function that executes ``nodes``.

        Arguments:
            nodes: the nodes to emit in execution order (iterated forwards and
                via ``reversed``).
            root_module: expression used as the base of ``call_module`` and
                ``get_attr`` targets in the generated code.
            namespace: used to create unique global names for referenced objects.
            verbose: if True, emit stack-trace summary comments and detailed
                dtype/shape annotations taken from ``node.meta``.
            include_stride, include_device: add strides / device to verbose
                FakeTensor annotations. Also forced on by the environment
                variables ``FX_GRAPH_SHOW_STRIDE=1`` / ``FX_GRAPH_SHOW_DEVICE=1``.

        Returns:
            A ``PythonCode`` holding the function source, the globals it needs,
            and a map from line number (counting the prologue lines) to the
            index of the node that produced that line.
        """
        # Parameter strings for the generated signature (filled by placeholders).
        free_vars: List[str] = []
        # Source fragments of the function body; joined at the end.
        body: List[str] = []
        # Objects referenced by the generated code, keyed by their global name.
        globals_: Dict[str, Any] = {}
        # Names of functions to re-register with ``torch.fx.wrap``; a dict is
        # used as an insertion-ordered set.
        wrapped_fns: Dict[str, None] = {}

        # Wrap string in list to pass by reference
        maybe_return_annotation : List[str] = ['']
        include_stride = include_stride or (os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1")
        include_device = include_device or (os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1")

        def add_global(name_hint: str, obj: Any):
            """Add an obj to be tracked as a global.

            We call this for names that reference objects external to the
            Graph, like functions or types.

            Returns: the global name that should be used to reference 'obj' in generated source.
            """
            if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
                # HACK: workaround for how torch custom ops are registered. We
                # can't import them like normal modules so they must retain their
                # fully qualified name.
                return _get_qualified_name(obj)

            # normalize the name hint to get a proper identifier
            global_name = namespace.create_name(name_hint, obj)

            if global_name in globals_:
                # The namespace maps each object to one name, so a repeated
                # name must refer to the same object.
                assert globals_[global_name] is obj
                return global_name
            globals_[global_name] = obj
            return global_name

        # Pre-fill the globals table with registered builtins.
        for name, (_, obj) in _custom_builtins.items():
            add_global(name, obj)

        def type_repr(o : Any):
            """Return source text for the type annotation ``o``, registering any
            globals it needs."""
            if o == ():
                # Empty tuple is used for empty tuple type annotation Tuple[()]
                return '()'

            typename = _type_repr(o)

            if hasattr(o, '__origin__'):
                # This is a generic type, e.g. typing.List[torch.Tensor]
                origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                origin_typename = add_global(_type_repr(origin_type), origin_type)

                if hasattr(o, '__args__'):
                    # Assign global names for each of the inner type variables.
                    args = [type_repr(arg) for arg in o.__args__]

                    if len(args) == 0:
                        # Bare type, such as `typing.Tuple` with no subscript
                        # This code-path used in Python < 3.9
                        return origin_typename

                    return f'{origin_typename}[{",".join(args)}]'
                else:
                    # Bare type, such as `typing.Tuple` with no subscript
                    # This code-path used in Python 3.9+
                    return origin_typename

            # Common case: this is a regular module name like 'foo.bar.baz'
            return add_global(typename, o)

        def _get_repr(arg: Any) -> str:
            """Return source text for a node argument, registering globals for
            NamedTuple types, OpOverloads and Enum classes."""
            # Handle NamedTuples (if it has `_fields`) via add_global.
            if isinstance(arg, tuple) and hasattr(arg, '_fields'):
                qualified_name = _get_qualified_name(type(arg))
                global_name = add_global(qualified_name, type(arg))
                return f"{global_name}{repr(tuple(arg))}"
            elif isinstance(arg, torch._ops.OpOverload):
                qualified_name = _get_qualified_name(arg)
                global_name = add_global(qualified_name, arg)
                return f"{global_name}"
            elif isinstance(arg, enum.Enum):
                cls = arg.__class__
                clsname = add_global(cls.__name__, cls)
                return f"{clsname}.{arg.name}"
            return repr(arg)

        def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
            """Render positional and keyword arguments as a call-argument list."""
            args_s = ', '.join(_get_repr(a) for a in args)
            kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
            if args_s and kwargs_s:
                return f'{args_s}, {kwargs_s}'
            return args_s or kwargs_s

        # Run through reverse nodes and record the first instance of a use
        # of a given node. This represents the *last* use of the node in the
        # execution order of the program, which we will use to free unused
        # values
        node_to_last_use : Dict[Node, Node] = {}
        user_to_last_uses : Dict[Node, List[Node]] = {}

        def register_last_uses(n : Node, user : Node):
            if n not in node_to_last_use:
                node_to_last_use[n] = user
                user_to_last_uses.setdefault(user, []).append(n)

        for node in reversed(nodes):
            map_arg(node.args, lambda n: register_last_uses(n, node))
            map_arg(node.kwargs, lambda n: register_last_uses(n, node))

        def delete_unused_values(user : Node):
            """
            Delete values after their last use. This ensures that values that are
            not used in the remainder of the code are freed and the memory usage
            of the code is optimal.

            This also appends the newline that terminates the line emitted for
            ``user`` (placeholders emit their own newline and are skipped).
            """
            if user.op == 'placeholder':
                return
            if user.op == 'output':
                body.append('\n')
                return
            nodes_to_delete = user_to_last_uses.get(user, [])
            if len(nodes_to_delete):
                # Rebind every value whose last use is ``user`` to None on the
                # same line, e.g. "; a = b = None".
                to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                body.append(f'; {to_delete_str}\n')
            else:
                body.append('\n')

        # Stack trace of the most recently summarized node; "" once the
        # "No stacktrace found" marker has been emitted.
        prev_stacktrace = None

        def append_stacktrace_summary(node : Node):
            """
            Append a summary of the stacktrace to the generated code. This is
            useful for debugging.

            A comment is only emitted when the stack trace differs from the
            previous node's, so runs of nodes with the same trace share one comment.
            """
            nonlocal prev_stacktrace

            if node.op not in {'placeholder', 'output'}:
                if node.stack_trace:
                    if node.stack_trace != prev_stacktrace:
                        prev_stacktrace = node.stack_trace
                        summary_str = ""

                        if parsed_stack_trace := _parse_stack_trace(node.stack_trace):
                            summary_str = parsed_stack_trace.get_summary_str()

                        body.append(f'\n# {summary_str}\n')
                elif prev_stacktrace != "":
                    prev_stacktrace = ""
                    body.append('\n# No stacktrace found for following nodes\n')

        def stringify_shape(shape : Iterable) -> str:
            return f"[{', '.join(str(x) for x in shape)}]"

        def emit_node(node : Node):
            """Append the source for ``node`` to ``body`` (without a trailing newline)."""
            maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'

            if verbose:
                # override annotation with more detailed information
                from torch._subclasses.fake_tensor import FakeTensor
                from torch.fx.experimental.proxy_tensor import py_sym_types
                from torch.fx.passes.shape_prop import TensorMetadata

                # Prefer meta['val'], then meta['tensor_meta'], then meta['example_value'].
                meta_val = node.meta.get('val', node.meta.get('tensor_meta', node.meta.get('example_value', None)))
                # use string as annotation, to make it valid python code

                # NOTE(review): dtype_abbrs is indexed directly below, so a dtype
                # missing from that table raises KeyError in verbose mode.
                if isinstance(meta_val, FakeTensor):
                    stride_annotation = f"{stringify_shape(meta_val.stride())}" if include_stride else ""
                    device_annotation = f"{meta_val.device}" if include_device else ""
                    maybe_type_annotation = \
                        f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}' \
                        f'{stride_annotation}{device_annotation}"'
                elif isinstance(meta_val, py_sym_types):
                    maybe_type_annotation = f': "Sym({meta_val})"'
                elif isinstance(meta_val, TensorMetadata):
                    maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'

            if node.op == 'placeholder':
                assert isinstance(node.target, str)
                maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}'
                free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
                # Strip "*" from varargs/kwargs targets to get the bare name.
                raw_name = node.target.replace('*', '')
                if raw_name != repr(node):
                    body.append(f'{repr(node)} = {raw_name}\n')
                return
            elif node.op == 'call_method':
                assert isinstance(node.target, str)
                body.append(
                    f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}'
                    f'({_format_args(node.args[1:], node.kwargs)})')
                return
            elif node.op == 'call_function':
                assert callable(node.target)
                # NOTE(review): magic_methods / inplace_methods are not defined in
                # this section; from their use they map an operator's __name__
                # to a str.format template.
                # pretty print operators
                if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods:
                    assert isinstance(node.args, tuple)
                    body.append(f'{repr(node)}{maybe_type_annotation} = '
                                f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}')
                    return

                # pretty print inplace operators; required for jit.script to work properly
                # not currently supported in normal FX graphs, but generated by torchdynamo
                if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods:
                    body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; '
                                f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}')
                    return

                qualified_name = _get_qualified_name(node.target)
                global_name = add_global(qualified_name, node.target)
                # special case for getattr: node.args could be 2-argument or 3-argument
                # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
                if global_name == 'getattr' and \
                   isinstance(node.args, tuple) and \
                   isinstance(node.args[1], str) and \
                   node.args[1].isidentifier() and \
                   len(node.args) == 2:
                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}')
                    return
                body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
                if node.meta.get('is_wrapped', False):
                    wrapped_fns.setdefault(global_name)
                return
            elif node.op == 'call_module':
                assert isinstance(node.target, str)
                body.append(f'{repr(node)}{maybe_type_annotation} = '
                            f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
                return
            elif node.op == 'get_attr':
                assert isinstance(node.target, str)
                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
                return
            elif node.op == 'output':
                if node.type is not None:
                    maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                body.append(self.generate_output(node.args[0]))
                return
            raise NotImplementedError(f'node: {node.op} {node.target}')

        for i, node in enumerate(nodes):
            # NOTE: emit_node does not emit a string with newline. It depends
            # on delete_unused_values to append one
            if verbose:
                append_stacktrace_summary(node)
            # emit a counter comment to keep track of
            # node index, which will be deleted later
            # after going through _body_transformer
            body.append(f"# COUNTER: {i}\n")
            emit_node(node)
            delete_unused_values(node)

        if len(body) == 0:
            # If the Graph has no non-placeholder nodes, no lines for the body
            # have been emitted. To continue to have valid Python code, emit a
            # single pass statement
            body.append('pass\n')

        if len(wrapped_fns) > 0:
            wrap_name = add_global('wrap', torch.fx.wrap)
            wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
        else:
            wrap_stmts = ''

        if self._body_transformer:
            body = self._body_transformer(body)

        for name, value in self.additional_globals():
            add_global(name, value)

        prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])

        # remove counter and generate lineno to node index mapping
        lineno_map: Dict[int, Optional[int]] = {}
        prologue_len = prologue.count('\n') + 1
        new_lines: List[str] = []
        cur_idx = None
        for line in ''.join(body).split('\n'):
            counter = re.search(r"# COUNTER: (\d+)", line)
            if counter and counter.group(1) is not None:
                # Counter lines are dropped; they only set the current node index.
                cur_idx = int(counter.group(1))
            else:
                lineno_map[len(new_lines) + prologue_len] = cur_idx
                new_lines.append(line)

        code = "\n".join(new_lines).lstrip('\n')
        # Indent every body line under the function definition.
        code = '\n'.join(' ' + line for line in code.split('\n'))

        fn_code = f"""
{wrap_stmts}
{prologue}
{code}"""
        return PythonCode(fn_code, globals_, _lineno_map=lineno_map)
# Ideally, we'd like to refactor all of the pytree logic into this codegen
# class. Unfortunately, there are 3 areas we currently need extra logic in FX.
# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`.
# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec.
#    Since we can't access .graph within the FX forward, we need to copy the attribute to the module.
# 3. We currently can't register the pytree imports with `add_global` - not sure why.
class _PyTreeCodeGen(CodeGen):
    """CodeGen whose generated function accepts the original (nested) inputs,
    flattens them with ``fx_pytree.tree_flatten_spec`` and, when an
    ``out_spec`` is present, unflattens the result with ``pytree.tree_unflatten``."""

    def __init__(self, pytree_info: _PyTreeInfo):
        super().__init__()
        # The methods below also guard against pytree_info being None.
        self.pytree_info: _PyTreeInfo = pytree_info

    def process_inputs(self, *inputs: Any) -> Any:
        """Flatten the user-facing inputs into the graph's flat argument list."""
        flat_args = pytree.arg_tree_leaves(*inputs)
        return flat_args

    def process_outputs(self, out: Any) -> Any:
        """Rebuild the user-facing output structure from the graph's outputs.

        Returns ``out`` unchanged when there is no ``out_spec``; a single
        non-list/tuple output is wrapped in a list before unflattening.
        """
        if self.pytree_info is None or self.pytree_info.out_spec is None:
            return out
        if not isinstance(out, (list, tuple)):
            out = [out]
        assert self.pytree_info.out_spec is not None
        return pytree.tree_unflatten(out, self.pytree_info.out_spec)

    def gen_fn_def(self, free_vars, maybe_return_annotation):
        # Given a user function/model:
        #   myargs = [myargs0, myargs1]
        #   mykwargs = {'mykwargs0': ..., 'mykwargs1': ...}
        #   def forward(self, mypos, *myargs, mykey=None, **mykwargs):
        #
        # The generated code flattens all keywords into positional arguments for `forward()`
        #   e.g forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1):
        #
        # Within `forward`, `tree_flatten_spec` still parses args and kwargs separately
        #   e.g. tree_flatten_spec(([mypos, myargs0, myargs1],
        #                           {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}),
        #                          self._in_spec)
        #
        # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec
        #   e.g. tree_flatten_spec([mypos, myargs0, myargs1], self._in_spec)
        if self.pytree_info is None:
            return super().gen_fn_def(free_vars, maybe_return_annotation)

        fn_args = self.pytree_info.orig_args
        has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False
        if has_orig_self:
            free_vars.insert(0, 'self')
        # The signature uses the original argument names; a copy is passed
        # because CodeGen.gen_fn_def may insert 'self' into the list it receives.
        fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation)

        if len(free_vars) > 0:  # pytree has placeholders in it
            # when kwargs is present, in_spec is tuple(args, kwargs)
            has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \
                self.pytree_info.in_spec.num_children == 2 and \
                self.pytree_info.in_spec.children_specs[0].type == tuple and \
                self.pytree_info.in_spec.children_specs[1].type == dict
            # NOTE(review): this initial '{}' is never read; fn_kwargs is only
            # used after being reassigned in the branch below.
            fn_kwargs = '{}'
            fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
            if has_args_kwargs_tuple:
                count_args = self.pytree_info.in_spec.children_specs[0].num_children
                fn_args = self.pytree_info.orig_args[:count_args]
                # Pair the kwargs spec's keys with the remaining argument names.
                fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip(
                    self.pytree_info.in_spec.children_specs[1].context,
                    self.pytree_info.orig_args[count_args:])) + '}'
                fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"

            # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
            # we need to split it to two lines:
            # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
            # one for code: `var1, var2, = function_call()`
            without_annotation = [x.split(":")[0] for x in free_vars]
            has_annotation = [x + "; " for x in free_vars if ":" in x]
            if len(has_annotation) > 0:
                fn_definition += "\n " + "".join(has_annotation) + "\n"
            # NOTE(review): the flatten statement below starts at column 0 of the
            # triple-quoted string in this copy of the file; it must carry the same
            # indentation as the rest of the generated function body (the prefix
            # used in CodeGen._gen_python_code), otherwise the generated code has
            # an indentation error. Confirm against the upstream source.
            fn_definition += f"""
{', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
        return fn_definition

    def generate_output(self, output_args):
        """Emit a return that rebuilds the output structure via ``self._out_spec``
        when an out_spec is present; otherwise defer to ``CodeGen``."""
        if self.pytree_info and self.pytree_info.out_spec:
            return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)'
        else:
            return super().generate_output(output_args)
class _FindNodesLookupTable:
    """
    Side table for the graph for the purpose of doing fast queries.

    Nodes are bucketed by ``(op, target)`` for ``call_function`` nodes and by
    ``(op, None)`` for every other op, so a query only scans one bucket. Each
    bucket is a dict used as an insertion-ordered set (values are ``None``).
    """

    def __init__(self):
        self.table: 'Dict[Tuple[str, Optional[Target]], Dict[Node, None]]' = defaultdict(dict)

    def _key(self, node) -> 'Tuple[str, Optional[Target]]':
        # Only call_function nodes are keyed by target; all other ops share a
        # single bucket per op and are filtered by target at query time.
        return (node.op, node.target if node.op == "call_function" else None)

    def __contains__(self, node) -> bool:
        # ``.get`` (not ``[]``) so a membership test never creates an empty
        # bucket in the defaultdict as a side effect.
        return node in self.table.get(self._key(node), ())

    def insert(self, node: 'Node') -> None:
        self.table[self._key(node)][node] = None

    def remove(self, node: 'Node') -> None:
        # Raises KeyError if ``node`` was never inserted.
        self.table[self._key(node)].pop(node)

    def find_nodes(self, *, op: str, target: Optional['Target'] = None):
        """
        Return the nodes with the given ``op`` (and ``target``, if given).

        ``target`` is required when ``op == "call_function"``. For
        ``call_function`` queries and queries without a target the result is
        a snapshot ``dict_keys`` view; otherwise it is a list. In both cases
        the nodes are in insertion order. Querying never adds buckets to the
        table.
        """
        if op == "call_function":
            assert target is not None
            return dict(self.table.get((op, target), {})).keys()

        bucket = self.table.get((op, None), {})
        if target is None:
            return dict(bucket).keys()

        # op is call_method, get_attr, call_module
        return [node for node in bucket.keys() if node.target == target]
@compatibility(is_backward_compatible=True)
class Graph:
    """
    ``Graph`` is the main data structure used in the FX Intermediate Representation.
    It consists of a series of ``Node`` s, each representing callsites (or other
    syntactic constructs). The list of ``Node`` s, taken together, constitutes a
    valid Python function.

    For example, the following code

    .. code-block:: python

        import torch
        import torch.fx

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

    Will produce the following Graph::

        print(gm.graph)

    .. code-block:: text

        graph(x):
            %linear_weight : [num_users=1] = self.linear.weight
            %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
            %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
            %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
            %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
            %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
            return topk_1

    For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None,
                 tracer_extras: Optional[Dict[str, Any]] = None):
        """
        Construct an empty Graph.

        Args:
            owning_module (Optional[GraphModule]): the module that owns this graph, if any.
                When set, ``get_attr``/``call_module`` targets are checked against it,
                and its ``_create_node_hooks``/``_erase_node_hooks`` (if present) are run.
            tracer_cls (Optional[Type[Tracer]]): the tracer class associated with this
                graph, if any; it is carried over by ``__deepcopy__``.
            tracer_extras (Optional[Dict[str, Any]]): extra tracer-specific data, stored as-is.
        """
        # Internal sentinel node (op 'root') that anchors the node list; it is not
        # counted in ``_len``.
        self._root : Node = Node(self, '', 'root', '', (), {})
        self._used_names : Dict[str, int] = {}  # base name -> number
        # New nodes are inserted via ``_root.prepend`` until the insert point is
        # changed with inserting_before/inserting_after.
        self._insert = self._root.prepend
        self._len = 0
        self._graph_namespace = _Namespace()
        self._owning_module = owning_module
        self._tracer_cls = tracer_cls
        self._tracer_extras = tracer_extras
        self._codegen = CodeGen()
        self._co_fields : Dict[str, Any] = {}
        self._find_nodes_lookup_table = _FindNodesLookupTable()

    @property
    def owning_module(self):
        return self._owning_module

    @owning_module.setter
    def owning_module(self, mod: Optional["GraphModule"]):
        self._owning_module = mod

    @property
    def nodes(self) -> _node_list:
        """
        Get the list of Nodes that constitute this Graph.

        Note that this ``Node`` list representation is a doubly-linked list. Mutations
        during iteration (e.g. delete a Node, add a Node) are safe.

        Returns:

            A doubly-linked list of Nodes. Note that ``reversed`` can be called on
            this list to switch iteration order.
        """
        return _node_list(self)

    @compatibility(is_backward_compatible=False)
    def find_nodes(self, *, op: str, target: Optional['Target'] = None, sort: bool = True):
        """
        Allows for fast query of nodes

        Args:

            op (str): the name of the operation

            target (Optional[Target]): the target of the node. For call_function,
                the target is required. For other ops, the target is optional.

            sort (bool): whether to return nodes in the order they appear on
                the graph.

        Returns:

            Iterable of nodes with the requested op and target.
        """
        node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target)
        if sort:
            return sorted(node_list)
        return node_list

    @compatibility(is_backward_compatible=True)
    def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]':
        """
        Copy all nodes from a given graph into ``self``.

        Copying stops at ``g``'s ``output`` node, which is not itself copied.

        Args:

            g (Graph): The source graph from which to copy Nodes.

            val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping
                from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed
                in with values in it already to override copying of certain values.

            return_output_node (bool): if True and ``g`` has an ``output`` node, return
                a ``(value, output_node)`` tuple, where ``output_node`` is ``g``'s
                original output node.

        Returns:

            The value in ``self`` that is now equivalent to the output value in ``g``,
            if ``g`` had an ``output`` node. ``None`` otherwise. With
            ``return_output_node=True`` the value is returned as part of a
            ``(value, output_node)`` tuple.
        """
        for node in g.nodes:
            if node in val_map:
                continue
            if node.op == 'output':
                rv = map_arg(node.args[0], lambda n: val_map[n])
                return rv if not return_output_node else (rv, node)
            val_map[node] = self.node_copy(node, lambda n : val_map[n])
        return None

    def __deepcopy__(self, memo=None) -> 'Graph':
        """
        Explicitly implement __deepcopy__ to prevent excessive recursion depth
        from the default implementation. This uses graph_copy to copy the nodes
        in an iterative way, rather than recursive. It also populates the
        memoization table to prevent unnecessary copies (e.g. references to
        nodes or other parts of the Graph from a custom GraphModule implementation).

        The source graph is expected to have an ``output`` node.
        """
        # NOTE(review): an empty memo dict is falsy, so it is replaced by a fresh
        # dict here and the caller's (empty) memo is not populated.
        memo = memo if memo else {}
        g = Graph(tracer_cls=self._tracer_cls)
        output_vals = g.graph_copy(self, val_map=memo, return_output_node=True)
        g._codegen = copy.deepcopy(self._codegen)
        # graph_copy only returns a tuple when it reached an output node.
        assert isinstance(output_vals, tuple)
        output_val, old_output_node = output_vals
        new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None))
        new_output_node.meta = copy.copy(old_output_node.meta)
        return g

    @compatibility(is_backward_compatible=True)
    def create_node(self, op: str, target: 'Target',
                    args: Optional[Tuple['Argument', ...]] = None,
                    kwargs: Optional[Dict[str, 'Argument']] = None,
                    name: Optional[str] = None,
                    type_expr: Optional[Any] = None) -> Node:
        """
        Create a ``Node`` and add it to the ``Graph`` at the current insert-point.
        Note that the current insert-point can be set via :meth:`Graph.inserting_before`
        and :meth:`Graph.inserting_after`.

        Args:
            op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr',
                'call_module', 'placeholder', or 'output'. See :class:`Node` for the
                semantics of these opcodes.

            target (Target): what this Node refers to: a callable for ``call_function``;
                otherwise a string (method name, qualified attribute/module name,
                placeholder name, or ``'output'``).

            args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node.

            kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node

            name (Optional[str]): an optional string name for the ``Node``.
                This will influence the name of the value assigned to in the
                Python generated code. The requested name (or one derived from
                ``target``) is passed through this graph's namespace, so the
                final name may differ from it.

            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

        Returns:

            The newly-created and inserted node.
        """
        assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
        args = () if args is None else args
        kwargs = {} if kwargs is None else kwargs
        assert isinstance(args, tuple), "args must be a tuple"
        assert isinstance(kwargs, dict), "kwargs must be a dict"

        candidate = name if name is not None else self._target_to_str(target)
        name = self._graph_namespace.create_name(candidate, None)
        n = Node(self, name, op, target, args, kwargs, type_expr)

        # Creation hooks registered on the owning module run before the node is
        # bound in the namespace and linked into the graph.
        if self.owning_module is not None and getattr(self.owning_module, "_create_node_hooks", None) is not None:
            for f in self.owning_module._create_node_hooks:
                f(n)

        self._graph_namespace.associate_name_with_obj(name, n)

        self._insert(n)
        self._find_nodes_lookup_table.insert(n)
        self._len += 1
        return n

    @compatibility(is_backward_compatible=False)
    def process_inputs(self, *args):
        """
        Processes args so that they can be passed to the FX graph.
        """
        return self._codegen.process_inputs(*args)

    @compatibility(is_backward_compatible=False)
    def process_outputs(self, out):
        """
        Processes the graph's outputs using this graph's codegen
        (the counterpart of :meth:`process_inputs`).
        """
        return self._codegen.process_outputs(out)

    @compatibility(is_backward_compatible=True)
    def erase_node(self, to_erase : Node) -> None:
        """
        Erases a ``Node`` from the ``Graph``. Throws an exception if
        there are still users of that node in the ``Graph``.

        Args:

            to_erase (Node): The ``Node`` to erase from the ``Graph``.

        Raises:

            RuntimeError: if ``to_erase`` still has users, or belongs to a
                different graph. Erasing an already-erased node only warns.
        """
        if len(to_erase.users) > 0:
            raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} '
                               f'users in the graph: {to_erase.users}!')
        if to_erase.graph != self:
            raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!")
        if to_erase._erased:
            warnings.warn(f"erase_node({to_erase}) on an already erased node")
            return

        if self.owning_module is not None and getattr(self.owning_module, "_erase_node_hooks", None) is not None:
            for f in self.owning_module._erase_node_hooks:
                f(to_erase)

        self._find_nodes_lookup_table.remove(to_erase)
        to_erase._remove_from_list()
        to_erase._erased = True  # iterators may retain handles to erased nodes
        self._len -= 1

        # Null out this Node's argument nodes so that the Nodes referred to
        # can update their ``users`` accordingly
        new_args = map_arg(to_erase.args, lambda n: None)
        assert isinstance(new_args, tuple)
        to_erase.args = new_args
        new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
        assert isinstance(new_kwargs, dict)
        to_erase.kwargs = new_kwargs

    @compatibility(is_backward_compatible=True)
    def inserting_before(self, n: Optional[Node] = None):
        """Set the point at which create_node and companion methods will insert into the graph.
        When used within a 'with' statement, this will temporarily set the insert point and
        then restore it when the with statement exits::

            with g.inserting_before(n):
                ... # inserting before node n
            ... # insert point restored to what it was previously
            g.inserting_before(n) #  set the insert point permanently

        Args:

            n (Optional[Node]): The node before which to insert. If None this will insert at
                the beginning of the entire graph (immediately after the internal root node).

        Returns:
            A context manager that will restore the insert point on ``__exit__``.
        """
        if n is None:
            return self.inserting_after(self._root)
        assert n.graph == self, "Node to insert before is not in graph."
        return _InsertPoint(self, n.prepend)

    @compatibility(is_backward_compatible=True)
    def inserting_after(self, n: Optional[Node] = None):
        """Set the point at which create_node and companion methods will insert into the graph.
        When used within a 'with' statement, this will temporarily set the insert point and
        then restore it when the with statement exits::

            with g.inserting_after(n):
                ... # inserting after node n
            ... # insert point restored to what it was previously
            g.inserting_after(n) #  set the insert point permanently

        Args:

            n (Optional[Node]): The node after which to insert. If None this will insert
                immediately before the internal root node, i.e. at the end of the graph --
                the same insertion point a freshly constructed ``Graph`` uses.

        Returns:
            A context manager that will restore the insert point on ``__exit__``.
        """
        if n is None:
            return self.inserting_before(self._root)
        assert n.graph == self, "Node to insert after is not in graph."
        return _InsertPoint(self, n.append)

    @compatibility(is_backward_compatible=True)
    def placeholder(self, name: str, type_expr: Optional[Any] = None,
                    default_value : Any = inspect.Signature.empty) -> Node:
        """
        Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents
        a function input.

        Args:

            name (str): A name for the input value. This corresponds to the name
                of the positional argument to the function this ``Graph`` represents.

            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have. This is needed in some
                cases for proper code generation (e.g. when the function is used
                subsequently in TorchScript compilation).

            default_value (Any): The default value this function argument should take
                on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty`
                should be passed as this argument to specify that the parameter does _not_
                have a default value.

        Returns:

            The newly-created and inserted ``placeholder`` node.

        .. note::
            The same insertion point and type expression rules apply for this method
            as ``Graph.create_node``.
        """
        # A default value, when present, is stored as the node's single arg.
        args = () if default_value is inspect.Signature.empty else (default_value,)
        return self.create_node('placeholder', name, args=args, type_expr=type_expr)

    @compatibility(is_backward_compatible=True)
    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
        """
        Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the
        fetch of an attribute from the ``Module`` hierarchy.

        Args:

            qualified_name (str): the fully-qualified name of the attribute to be retrieved.
                For example, if the traced Module has a submodule named ``foo``, which has a
                submodule named ``bar``, which has an attribute named ``baz``, the qualified
                name ``foo.bar.baz`` should be passed as ``qualified_name``.

            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.


        Returns:

            The newly-created and inserted ``get_attr`` node.

        .. note::
            The same insertion point and type expression rules apply for this method
            as ``Graph.create_node``.
        """
        def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool:
            # True only if the attribute exists and is a Module, a Parameter, or
            # a registered buffer of its parent module.
            module_path, _, name = qualified_name.rpartition(".")

            try:
                submod: torch.nn.Module = mod.get_submodule(module_path)
            except AttributeError:
                warnings.warn(f"Failed to fetch module {module_path}!")
                return False

            if not hasattr(submod, name):
                return False

            res = getattr(submod, name)

            if (not isinstance(res, torch.nn.Module)
                    and not isinstance(res, torch.nn.Parameter)
                    and name not in submod._buffers):
                return False

            return True

        if (self.owning_module and
                not _get_attr_reference_exists(self.owning_module, qualified_name)):
            warnings.warn("Attempted to insert a get_attr Node with no "
                          "underlying reference in the owning "
                          "GraphModule! Call "
                          "GraphModule.add_submodule to add the "
                          "necessary submodule, "
                          "GraphModule.add_parameter to add the "
                          "necessary Parameter, or "
                          "nn.Module.register_buffer to add the "
                          "necessary buffer", stacklevel=2)
        return self.create_node('get_attr', qualified_name, type_expr=type_expr)

    @compatibility(is_backward_compatible=True)
    def call_module(self,
                    module_name: str,
                    args: Optional[Tuple['Argument', ...]] = None,
                    kwargs: Optional[Dict[str, 'Argument']] = None,
                    type_expr: Optional[Any] = None) -> Node:
        """
        Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node
        represents a call to the forward() function of a ``Module`` in the ``Module``
        hierarchy.

        Args:

            module_name (str): The qualified name of the ``Module`` in the ``Module``
                hierarchy to be called. For example, if the traced ``Module`` has a
                submodule named ``foo``, which has a submodule named ``bar``, the
                qualified name ``foo.bar`` should be passed as ``module_name`` to
                call that module.

            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
                to the called method. Note that this should *not* include a ``self`` argument.

            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
                to the called method

            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

        Returns:

            The newly-created and inserted ``call_module`` node.

        .. note::
            The same insertion point and type expression rules apply for this method
            as :meth:`Graph.create_node`. If the owning module has no submodule at
            ``module_name``, a warning is emitted and the node is still created.
        """
        def _submodule_exists(mod, qualified_name: str) -> bool:
            # ``get_submodule`` raises AttributeError for a missing path (as
            # get_attr above also assumes) rather than returning None, so both
            # outcomes count as "no underlying reference".
            try:
                return mod.get_submodule(qualified_name) is not None
            except AttributeError:
                return False

        if (self.owning_module and
                not _submodule_exists(self.owning_module, module_name)):
            warnings.warn("Attempted to insert a call_module Node with "
                          "no underlying reference in the owning "
                          "GraphModule! Call "
                          "GraphModule.add_submodule to add the "
                          "necessary submodule", stacklevel=2)
        return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)

    @compatibility(is_backward_compatible=True)
    def call_method(self,
                    method_name: str,
                    args: Optional[Tuple['Argument', ...]] = None,
                    kwargs: Optional[Dict[str, 'Argument']] = None,
                    type_expr: Optional[Any] = None) -> Node:
        """
        Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node
        represents a call to a given method on the 0th element of ``args``.

        Args:

            method_name (str): The name of the method to apply to the self argument.
                For example, if args[0] is a ``Node`` representing a ``Tensor``,
                then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``.

            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
                to the called method. Note that this *should* include a ``self`` argument.

            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
                to the called method

            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

        Returns:

            The newly created and inserted ``call_method`` node.

        .. note::
            The same insertion point and type expression rules apply for this method
            as :meth:`Graph.create_node`.
        """
        return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)

    @compatibility(is_backward_compatible=True)
    def call_function(self,
                      the_function: Callable[..., Any],
                      args: Optional[Tuple['Argument', ...]] = None,
                      kwargs: Optional[Dict[str, 'Argument']] = None,
                      type_expr: Optional[Any] = None) -> Node:
        """
        Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node
        represents a call to a Python callable, specified by ``the_function``.

        Args:

            the_function (Callable[..., Any]): The function to be called. Can be any PyTorch
                operator, Python function, or member of the ``builtins`` or ``operator``
                namespaces.

            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
                to the called function.

            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
                to the called function

            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

        Returns:

            The newly created and inserted ``call_function`` node.

        .. note::
            The same insertion point and type expression rules apply for this method
            as :meth:`Graph.create_node`.
        """
        return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)

    @compatibility(is_backward_compatible=True)
    def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node:
        """
        Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from
        the graph of node to the graph of self. Example::

            # Copying all the nodes in `g` into `new_graph`
            g : torch.fx.Graph = ...
            new_graph = torch.fx.Graph()
            value_remap = {}
            for node in g.nodes:
                value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])

        Args:

            node (Node): The node to copy into ``self``.

            arg_transform (Callable[[Node], Argument]): A function that transforms
                ``Node`` arguments in node's ``args`` and ``kwargs`` into the
                equivalent argument in ``self``. In the simplest case, this should
                retrieve a value out of a table mapping Nodes in the original
                graph to ``self``.

        Returns:

            The newly-created node in ``self``. Its ``meta`` is a shallow copy of
            ``node.meta``.
        """
        args = map_arg(node.args, arg_transform)
        kwargs = map_arg(node.kwargs, arg_transform)
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)
        result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type)
        result_node.meta = copy.copy(node.meta)
        return result_node

    @compatibility(is_backward_compatible=True)
    def output(self, result: 'Argument', type_expr: Optional[Any] = None):
        """
        Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents
        a ``return`` statement in Python code. ``result`` is the value that should
        be returned.

        Args:

            result (Argument): The value to be returned.

            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

        Returns:

            The newly-created and inserted ``output`` node.

        .. note::

            The same insertion point and type expression rules apply for this method
            as ``Graph.create_node``.
        """
        return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr)

    def _target_to_str(self, target : Target) -> str:
        """
        Derive a base node name from ``target``: a callable's ``__name__`` or the
        string itself, with the surrounding double underscores removed for magic
        names, then passed through ``_snake_case``.
        """
        if callable(target):
            op = target.__name__
        else:
            assert isinstance(target, str)
            op = target
            if _is_magic(op):
                op = op[2:-2]
        op = _snake_case(op)
        return op

    @compatibility(is_backward_compatible=True)
    def python_code(
        self, root_module: str, *,
        verbose: bool = False, include_stride: bool = False, include_device: bool = False
    ) -> PythonCode:
        """
        Turn this ``Graph`` into valid Python code.

        Args:

            root_module (str): The name of the root module on which to look-up
                qualified name targets. This is usually 'self'.

            verbose, include_stride, include_device (bool): forwarded unchanged to
                the codegen's ``_gen_python_code``.

        Returns:

            A PythonCode object, consisting of two fields:
                src: the Python source code representing the object
                globals: a dictionary of global names in `src` -> the objects that they reference.
        """
        # NOTE: [Graph Namespaces]
        #
        # There are two types of symbols in generated Python source code:
        # locals and globals.
        #   Locals are locally defined by the output of a node in the Graph.
        #   Globals are references to external objects, like functions or types.
        #
        # When generating Python code, we need to make sure to name things
        # appropriately. In particular:
        # - All names should be unique, to avoid weird shadowing bugs.
        # - These names need to be consistent, e.g. a object should always be
        #   referenced by the same name.
        #
        # To do this, we create a new namespace just for this source. All names
        # that get printed must come from this namespace.
        #
        # Why can't we re-use node.name? Because it was generated within the
        # namespace `self._graph_namespace`. In order to provide uniqueness
        # over both locals (node.name) *and* globals, we create a completely
        # new namespace to put all identifiers in.
        namespace = _Namespace()

        # Override Node's repr to generate a valid name within our namespace.
        # Since repr() is designed to produce a valid Python expression, it
        # makes sense to re-use it. This way, it's easy to print something like
        # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is
        # implemented cooperatively to allow this.
        def node_repr(n: Node):
            return namespace.create_name(n.name, n)

        @contextmanager
        def override_node_repr(graph: Graph):
            orig_repr_fns = {}
            for node in graph.nodes:
                orig_repr_fns[node] = node._repr_fn
                node._repr_fn = node_repr
            try:
                yield None
            finally:
                # restore the original repr functions
                for node in graph.nodes:
                    node._repr_fn = orig_repr_fns[node]

        with override_node_repr(self):
            return self._python_code(
                root_module, namespace,
                verbose=verbose, include_stride=include_stride, include_device=include_device
            )

    def _python_code(
        self, root_module: str, namespace: _Namespace, *,
        verbose: bool = False, include_stride: bool = False, include_device: bool = False
    ) -> PythonCode:
        # Delegates to the graph's CodeGen with the caller-provided namespace.
        return self._codegen._gen_python_code(
            self.nodes, root_module, namespace,
            verbose=verbose, include_stride=include_stride, include_device=include_device
        )

    def __str__(self) -> str:
        """
        Return a human-readable (not machine-readable) string representation
        of this Graph
        """
        placeholder_names : List[str] = []
        # This is a one-element array just so ``format_node`` can modify the closed
        # over value
        # NOTE(review): it is not passed to ``format_node`` below, so it always
        # stays '' and no return type is printed -- confirm whether intended.
        maybe_return_typename : List[str] = ['']

        node_strs = [node.format_node(placeholder_names) for node in self.nodes]
        param_str = ', '.join(placeholder_names)
        s = f'graph({param_str}){maybe_return_typename[0]}:'
        for node_str in node_strs:
            if node_str:
                s += '\n ' + node_str
        return s

    @compatibility(is_backward_compatible=True)
    def print_tabular(self):
        """
        Prints the intermediate representation of the graph in tabular
        format. Note that this API requires the ``tabulate`` module to be
        installed.
        """
        try:
            from tabulate import tabulate
        except ImportError:
            print("`print_tabular` relies on the library `tabulate`, "
                  "which could not be found on this machine. Run `pip "
                  "install tabulate` to install the library.")
            raise

        node_specs = [[n.op, n.name, n.target, n.args, n.kwargs]
                      for n in self.nodes]
        print(tabulate(node_specs,
              headers=['opcode', 'name', 'target', 'args', 'kwargs']))

    @compatibility(is_backward_compatible=True)
    def lint(self):
        """
        Runs various checks on this Graph to make sure it is well-formed. In
        particular:
        - Checks Nodes have correct ownership (owned by this graph)
        - Checks Nodes appear in topological order
        - Checks every Node is registered in the find_nodes side table
        - Checks Node names are unique
        - If this Graph has an owning GraphModule, checks that targets
          exist in that GraphModule
        """

        # Check topo order
        def check_arg(arg : Node, n : Optional[Node] = None) -> None:
            context_str = f' of Node \'{n}\' ' if n else ' '
            if arg.graph is not self:
                raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, '
                                   f'but was used as an argument! If you are copying nodes from another graph, make '
                                   f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}')
            if arg not in seen_values:
                raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been '
                                   f'defined! Please check that Nodes in the graph are topologically ordered\n{self}')

        seen_names : Set[str] = set()
        seen_values : Set[Node] = set()
        for node in self.nodes:
            if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']:
                raise RuntimeError(f'Node {node} had unknown opcode {node.op}!')
            if node.graph is not self:
                raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!')
            if node not in self._find_nodes_lookup_table:
                raise RuntimeError(f"Node '{node}' is not added to the side table")
            map_arg(node.args, lambda arg: check_arg(arg, node))
            map_arg(node.kwargs, lambda arg: check_arg(arg, node))
            seen_values.add(node)

            if node.name in seen_names:
                raise RuntimeError(f'Node redefined name {node.name}!')
            seen_names.add(node.name)

        # Check targets are legit
        if self.owning_module:
            for node in self.nodes:
                if node.op == 'call_function':
                    if not callable(node.target):
                        raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
                                         'a Callable is expected')
                else:
                    if not isinstance(node.target, str):
                        raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
                                         'a str is expected')
                if node.op in ['get_attr', 'call_module']:
                    # Walk the dotted target one attribute at a time from the owning module.
                    target_atoms = node.target.split('.')
                    m_itr = self.owning_module
                    for i, atom in enumerate(target_atoms):
                        new_m_itr = getattr(m_itr, atom, None)
                        seen_qualname = '.'.join(target_atoms[:i])
                        if new_m_itr is None:
                            raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute '
                                               f'{atom} of {seen_qualname}')
                        if (node.op == "call_module"
                                and not isinstance(new_m_itr, torch.nn.Module)):
                            raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
                                               'not reference an nn.Module')
                        elif (node.op == "get_attr"
                              and not isinstance(new_m_itr, torch.nn.Module)
                              and not isinstance(new_m_itr, torch.nn.Parameter)
                              and atom not in m_itr._buffers):
                            warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
                                          'not reference an nn.Module, nn.Parameter, or buffer, which is '
                                          'what \'get_attr\' Nodes typically target')
                        else:
                            m_itr = new_m_itr

    @compatibility(is_backward_compatible=True)
    def eliminate_dead_code(self):
        """
        Remove all dead code from the graph, based on each node's number of
        users, and whether the nodes have any side effects. The graph must be
        topologically sorted before calling.

        Returns:
          bool: Whether the graph was changed as a result of the pass.

        Example:

        Before dead code is eliminated, `a` from `a = x + 1` below has no users
        and thus can be eliminated from the graph without having an effect.

        .. code-block:: python

            def forward(self, x):
                a = x + 1
                return x + self.attr_1

        After dead code is eliminated, `a = x + 1` has been removed, and the rest
        of `forward` remains.

        .. code-block:: python

            def forward(self, x):
                return x + self.attr_1

        .. warning::

            Dead code elimination has some heuristics to avoid removing
            side-effectful nodes (see Node.is_impure) but in general coverage
            is very bad, so you should assume that this method is not sound
            to call unless you know that your FX graph consists entirely
            of functional operations.
        """
        # Lint the graph first to make sure it's topologically sorted, otherwise
        # DCE below will not behave as expected.
        self.lint()

        # Reverse iterate so that when we remove a node, any nodes used as an
        # input to that node have an updated user count that no longer reflects
        # the removed node.
        changed = False
        for node in reversed(self.nodes):
            if not node.is_impure() and len(node.users) == 0:
                self.erase_node(node)
                changed = True

        return changed

    @compatibility(is_backward_compatible=False)
    def set_codegen(self, codegen: CodeGen):
        """Replace the ``CodeGen`` used to generate Python code for this graph."""
        self._codegen = codegen

    @compatibility(is_backward_compatible=False)
    def on_generate_code(
        self,
        make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]
    ):
        """Register a transformer function when python code is generated

        Args:
            make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):
                a function that returns a code transformer to be registered.
                This function is called by `on_generate_code` to obtain the
                code transformer.

                This function is also given as its input the currently
                registered code transformer (or None if nothing is registered),
                in case it is not desirable to overwrite it. This is useful to
                chain code transformers together.

        Returns:
            a context manager that, when used in a `with` statement, automatically
            restores the previously registered code transformer.

        Example:

        .. code-block:: python


            gm: fx.GraphModule = ...

            # This is a code transformer we want to register. This code
            # transformer prepends a pdb import and trace statement at the very
            # beginning of the generated torch.fx code to allow for manual
            # debugging with the PDB library.
            def insert_pdb(body):
                return ["import pdb; pdb.set_trace()\\n", *body]

            # Registers `insert_pdb`, and overwrites the current registered
            # code transformer (given by `_` to the lambda):
            gm.graph.on_generate_code(
                lambda _: insert_pdb
            )

            # Or alternatively, registers a code transformer which first
            # runs `body` through existing registered transformer, then
            # through `insert_pdb`:
            gm.graph.on_generate_code(
                lambda current_trans: (
                    lambda body: insert_pdb(
                        current_trans(body) if current_trans
                        else body
                    )
                )
            )

            gm.recompile()
            gm(*inputs)  # drops into pdb


        This function can also be used as a context manager, with the benefit of
        automatically restoring the previously registered code transformer:

        .. code-block:: python

            # ... continue from previous example

            with gm.graph.on_generate_code(lambda _: insert_pdb):
                # do more stuff with `gm`...
                gm.recompile()
                gm(*inputs)  # drops into pdb

            # now previous code transformer is restored (but `gm`'s code with pdb
            # remains - that means you can run `gm` with pdb here too, until you
            # run next `recompile()`).
        """
        on_gen_code_old = self._codegen._body_transformer
        self._codegen._body_transformer = make_transformer(on_gen_code_old)

        # The new transformer is already installed at this point; the returned
        # context manager only restores the old one on exit.
        @contextlib.contextmanager
        def on_generate_code_context_manager():
            try:
                yield
            finally:
                self._codegen._body_transformer = on_gen_code_old

        return on_generate_code_context_manager()
# Binary operators that also have a reflected (``__r<name>__``) form. Each value
# is a ``str.format`` template that receives the operands in order.
reflectable_magic_methods = dict(
    add='{} + {}',
    sub='{} - {}',
    mul='{} * {}',
    floordiv='{} // {}',
    truediv='{} / {}',
    div='{} / {}',
    mod='{} % {}',
    pow='{} ** {}',
    lshift='{} << {}',
    rshift='{} >> {}',
    and_='{} & {}',
    or_='{} | {}',
    xor='{} ^ {}',
    getitem='{}[{}]',
    matmul='{} @ {}',
)

# All supported magic methods: comparisons and unary operators first, followed
# by every reflectable binary operator above.
magic_methods = dict(
    eq='{} == {}',
    ne='{} != {}',
    lt='{} < {}',
    gt='{} > {}',
    le='{} <= {}',
    ge='{} >= {}',
    pos='+{}',
    neg='-{}',
    invert='~{}',
    **reflectable_magic_methods,
)

# In-place (augmented assignment) operators plus item assignment; these are
# statements rather than expressions.
inplace_methods = dict(
    iadd='{} += {}',
    iand='{} &= {}',
    ifloordiv='{} //= {}',
    ilshift='{} <<= {}',
    imod='{} %= {}',
    imul='{} *= {}',
    imatmul='{} @= {}',
    ior='{} |= {}',
    ipow='{} **= {}',
    irshift='{} >>= {}',
    isub='{} -= {}',
    itruediv='{} /= {}',
    ixor='{} ^= {}',
    setitem='{}[{}] = {}',
)
|