# mypy: allow-untyped-defs
import collections
import dataclasses
import re
import sys
import types
from typing import Counter, Dict, List, Optional

import torch.nn

from . import utils
from .bytecode_transformation import (
    create_call_function,
    create_call_method,
    create_dup_top,
    create_instruction,
    create_load_attr,
    create_load_global,
    create_load_method,
    create_rot_n,
    Instruction,
)
from .exc import unimplemented
from .source import AttrSource, Source
from .utils import is_safe_constant, rot_n_helper
from .variables.base import VariableTracker
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable


@dataclasses.dataclass
class GraphOutputEntry:
    index: int
    variable: VariableTracker


class PyCodegen:
    """
    Helper class used for constructing Python bytecode.
    """
    def __init__(
        self,
        tx=None,
        root: Optional[torch.nn.Module] = None,
        graph_output_var: Optional[str] = None,
        tempvars=None,
    ):
        self.root = root
        self.top_of_stack: Optional[VariableTracker] = None
        self.uses: Counter[VariableTracker] = collections.Counter()
        self.graph_outputs: Dict[int, GraphOutputEntry] = {}
        self._output: List[Instruction] = []
        self.tempvars = tempvars or {}
        self.tx = tx
        self.graph_output_var = graph_output_var
        self.code_options = self.tx.output.code_options
        self.cell_and_freevars = self.tx.cell_and_freevars
        self.new_var = self.tx.output.new_var
        self.mutable_side_effects_from_source = False
        self.value_from_source: bool = True

    def restore_stack(self, stack_values, *, value_from_source=True):
        prior = self.mutable_side_effects_from_source
        self.mutable_side_effects_from_source = True
        prev = self.value_from_source
        self.value_from_source &= value_from_source
        try:
            self.foreach(stack_values)
        finally:
            self.mutable_side_effects_from_source = prior
            self.value_from_source = prev

    def graph_output_vars(self):
        return [x.variable for x in self.graph_outputs.values()]

    def call_reconstruct(self, value):
        res = value.reconstruct(self)
        assert res is None, f"reconstruct!=None {value}"
    def __call__(self, value, allow_cache=True):
        """
        Generate code such that top-of-stack (TOS) is set to value.

        Depending on the kind of value, this replays the original source
        access, loads a constant, indexes into the graph outputs, walks
        attributes from the root module, or falls back to
        VariableTracker.reconstruct.
        """
        if isinstance(value, Source):
            self.call_reconstruct(value)
            self.clear_tos()
            return

        assert isinstance(value, VariableTracker)
        output = self._output
        graph_outputs = self.graph_outputs

        if self.top_of_stack is value and allow_cache:
            output.append(create_dup_top())
            return

        if self.mutable_side_effects_from_source:
            # this is needed to get aliasing relationships right
            # value.mutable_local.source will get mutated to hold `value`
            # mutable_side_effects_from_source=False is used to codegen the mutation
            # mutable_side_effects_from_source=True is used to codegen a reference
            from .side_effects import MutableSideEffects

            if isinstance(value.mutable_local, MutableSideEffects):
                self(value.mutable_local.source)
                return

        if allow_cache:
            if value.mutable_local and value.mutable_local in self.tempvars:
                output.append(self.create_load(self.tempvars[value.mutable_local]))
                self.top_of_stack = value
                return
            if self.tempvars.get(value) is not None:
                output.append(self.create_load(self.tempvars[value]))
                self.top_of_stack = value
                return

        if value.source is not None and allow_cache and self.value_from_source:
            self.call_reconstruct(value.source)
        elif value.is_python_constant() and is_safe_constant(
            value.as_python_constant()
        ):
            output.append(self.create_load_const(value.as_python_constant()))
        elif isinstance(value, TensorWithTFOverrideVariable):
            graph_outputs_key = self.add_graph_output(value)

            self.load_import_from(utils.__name__, "to_subclass")
            self.load_graph_output(graph_outputs[graph_outputs_key].index)
            output.append(
                self.create_load_global(
                    value.global_mangled_class_name(self.tx), False, add=True
                )
            )
            output.extend(create_call_function(2, True))
        elif (
            isinstance(value, SymNodeVariable)
            and value.python_type() == float
            and not self.tx.export
        ):
            # This is a little unusual; force the output convention to be a
            # Tensor here.  Don't do this for export because this is
            # apparently load bearing for export tests (but I am a bit
            # doubtful it actually works in the real world)
            # NB: It works to add_graph_output on a computed expression
            # as_tensor here, because we memoize as_tensor calls on
            # SymNodeVariable!
            graph_outputs_key = self.add_graph_output(value.as_tensor(self.tx))

            self.load_graph_output(graph_outputs[graph_outputs_key].index)
            output.extend(
                [self.create_load_attr("item")] + create_call_function(0, True)
            )
        elif isinstance(
            value,
            (
                TensorVariable,
                SymNodeVariable,
                UnspecializedPythonVariable,
                NumpyNdarrayVariable,
            ),
        ):
            graph_outputs_key = self.add_graph_output(value)

            if isinstance(value, NumpyNdarrayVariable):
                self.load_import_from(utils.__name__, "to_numpy_helper")

            self.load_graph_output(graph_outputs[graph_outputs_key].index)

            if isinstance(value, NumpyNdarrayVariable):
                output.extend(create_call_function(1, True))
            elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
                output.extend(
                    [self.create_load_attr("item")] + create_call_function(0, True)
                )
        elif isinstance(value, NNModuleVariable):
            parts = value.module_key.split(".")
            if parts[0] in self.code_options["co_varnames"]:
                output.append(self.create_load(parts[0]))
                parts = parts[1:]
            else:
                assert self.root is not None
                output.append(self.create_load_output(self.root))
            for part in parts:
                output.append(self.create_load_attr(part))
        else:
            self.uses[value] += 1
            try:
                self.call_reconstruct(value)
            except NotImplementedError:
                unimplemented(f"reconstruct: {value}")
            if allow_cache and value in self.tempvars:
                self._output.append(create_dup_top())
                self.add_cache(value)

        self.top_of_stack = value
    def add_graph_output(self, value):
        graph_outputs_key = id(value.as_proxy())
        if graph_outputs_key not in self.graph_outputs:
            self.graph_outputs[graph_outputs_key] = GraphOutputEntry(
                len(self.graph_outputs), value
            )
        return graph_outputs_key

    def load_graph_output(self, index):
        output = self._output
        output.append(self.create_load(self.graph_output_var))
        output.append(self._create_load_const(index))
        output.append(create_instruction("BINARY_SUBSCR"))
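    # For example, if graph_output_var happened to be "graph_out_0", then
    # load_graph_output(1) emits bytecode equivalent to the expression
    # `graph_out_0[1]`.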
    def add_cache(self, value):
        var = self.new_var()
        self.tempvars[value] = var
        if value.mutable_local:
            self.tempvars[value.mutable_local] = var
        self._output.append(self.create_store(var))

    def foreach(self, items):
        for i in items:
            self(i)

    def setup_globally_cached(self, name, value, push_null):
        """Store value in a new global"""
        name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        f_globals = self.tx.f_globals
        if name in f_globals:
            assert id(f_globals[name]) == id(value)
        else:
            f_globals[name] = value
        return [self.create_load_global(name, push_null, add=True)]

    def clear_tos(self):
        self.top_of_stack = None

    def append_output(self, inst):
        assert isinstance(inst, Instruction)
        self._output.append(inst)
        self.clear_tos()

    def extend_output(self, insts):
        assert all(isinstance(x, Instruction) for x in insts)
        self._output.extend(insts)
        self.clear_tos()

    def get_instructions(self) -> List[Instruction]:
        return self._output

    def create_load(self, name) -> Instruction:
        if name in self.cell_and_freevars():
            return create_instruction("LOAD_DEREF", argval=name)
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction("LOAD_FAST", argval=name)

    def create_load_closure(self, name) -> Instruction:
        assert name in self.cell_and_freevars()
        return create_instruction("LOAD_CLOSURE", argval=name)

    def create_store(self, name) -> Instruction:
        if name in self.cell_and_freevars():
            return create_instruction("STORE_DEREF", argval=name)
        assert name in self.code_options["co_varnames"]
        return create_instruction("STORE_FAST", argval=name)

    def create_load_global(self, name, push_null, add=False) -> Instruction:
        if add:
            self.tx.output.update_co_names(name)
        assert name in self.code_options["co_names"], f"{name} not in co_names"
        return create_load_global(name, push_null)

    def create_load_const(self, value) -> Instruction:
        assert is_safe_constant(value), f"unsafe constant {value}"
        return self._create_load_const(value)

    def _create_load_const(self, value) -> Instruction:
        return create_instruction("LOAD_CONST", argval=value)

    create_load_output = _create_load_const

    def create_load_method(self, name):
        self.tx.output.update_co_names(name)
        return create_load_method(name)

    def load_method(self, name):
        self.append_output(self.create_load_method(name))

    def call_method(self, nargs):
        self.extend_output(create_call_method(nargs))

    def create_load_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_load_attr(name)

    def load_attr(self, name):
        self.append_output(self.create_load_attr(name))

    def create_load_attrs(self, names):
        return [self.create_load_attr(name) for name in names.split(".")]

    def create_store_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_instruction("STORE_ATTR", argval=name)

    def store_attr(self, name):
        self.append_output(self.create_store_attr(name))
    def load_function_name(self, fn_name, push_null, num_on_stack=0):
        """Load the global fn_name onto the stack, num_on_stack entries down"""
        output = []
        if push_null and sys.version_info >= (3, 11):
            output.extend(
                [create_instruction("PUSH_NULL"), *self.rot_n(num_on_stack + 1)]
            )
        output.extend(
            [
                self.create_load_global(fn_name, False, add=True),
                *self.rot_n(num_on_stack + 1),
            ]
        )
        return output
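    # Note on PUSH_NULL in load_function_name above: starting with Python 3.11,
    # calling a plain function requires a NULL to be pushed below the callable
    # as part of the call sequence.  The rot_n calls then sink the NULL and the
    # loaded function beneath the num_on_stack values already on the stack.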
    def rot_n(self, n):
        try:
            return create_rot_n(n)
        except AttributeError:
            # desired rotate bytecode doesn't exist, generate equivalent bytecode
            return [
                create_instruction("BUILD_TUPLE", arg=n),
                self._create_load_const(rot_n_helper(n)),
                *create_rot_n(2),
                create_instruction("CALL_FUNCTION_EX", arg=0),
                create_instruction("UNPACK_SEQUENCE", arg=n),
            ]
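    # The fallback above works by packing the top n stack values into a tuple
    # (BUILD_TUPLE), calling rot_n_helper(n) on that tuple via CALL_FUNCTION_EX,
    # and unpacking the rotated result back onto the stack (UNPACK_SEQUENCE),
    # which reproduces the effect of the missing rotate instruction.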
    def pop_null(self):
        # POP_TOP doesn't work for null, so we pop nulls by pushing in a
        # nop function, calling it (which consumes the null), and popping the result.
        assert sys.version_info >= (3, 11)
        return [
            self._create_load_const(lambda: None),
            *create_call_function(0, False),
            create_instruction("POP_TOP"),
        ]

    def pop_top(self):
        self.append_output(create_instruction("POP_TOP"))

    def call_function(self, nargs: int, push_null: bool):
        self.extend_output(create_call_function(nargs, push_null=push_null))

    def dup_top(self):
        self.append_output(create_dup_top())

    def store(self, varname):
        self.append_output(self.create_store(varname))
    def make_function_with_closure(
        self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0
    ):
        freevars = code.co_freevars
        assert freevars
        output = self._output

        if sys.version_info >= (3, 11) and push_null:
            output.append(create_instruction("PUSH_NULL"))
            output.extend(self.rot_n(num_on_stack + 1))

        for var in freevars:
            assert var in self.cell_and_freevars()
            output.append(create_instruction("LOAD_CLOSURE", argval=var))
        output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
        output.append(self.create_load_const(code))
        if sys.version_info < (3, 11):
            output.append(self.create_load_const(fn_name))
        # flag 0x08 tells MAKE_FUNCTION that a closure tuple is on the stack
        output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
        output.extend(self.rot_n(num_on_stack + 1))
        self.clear_tos()

    def create_load_python_module(self, mod, push_null) -> Instruction:
        """
        Generate a LOAD_GLOBAL instruction to fetch a given Python module.
        """
        output = self.tx.output
        global_scope = output.global_scope
        name = re.sub(r"^.*[.]", "", mod.__name__)
        if global_scope.get(name, None) is mod:
            return self.create_load_global(name, push_null, add=True)
        prefix = f"___module_{name}"
        global_name = self.tx.output.install_global_by_id(prefix, mod)
        return self.create_load_global(global_name, push_null, add=True)
    def make_call_generated_code(self, fn_name: str) -> None:
        """Call the generated code function stored in fn_name"""
        self.extend_output(self.load_function_name(fn_name, True))

        graphargs = self.tx.output.graphargs
        for arg in graphargs:
            if arg.pass_arg_as_tensor:
                self.extend_output(
                    [
                        self.create_load_python_module(torch, True),
                        self.create_load_attr("as_tensor"),
                    ]
                )
                self.call_reconstruct(arg)
                self.extend_output(create_call_function(1, False))
            else:
                self.call_reconstruct(arg)

        self.extend_output(create_call_function(len(graphargs), False))
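    # Example of what make_call_generated_code emits: for a generated function
    # named "__compiled_fn_0" (name illustrative) with graph args (x, s), where
    # s must be passed as a tensor, the bytecode corresponds roughly to
    # `__compiled_fn_0(x, torch.as_tensor(s))`.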
    def load_import_from(self, module_name, object_name) -> None:
        self(AttrSource(self.tx.import_source(module_name), object_name))
    def create_call_function_kw(self, nargs, kw_names, push_null) -> List[Instruction]:
        if sys.version_info >= (3, 11):
            output = create_call_function(nargs, push_null)
            if sys.version_info >= (3, 12):
                idx = -1
                expected_inst = "CALL"
            else:
                idx = -2
                expected_inst = "PRECALL"
            assert output[idx].opname == expected_inst
            # on 3.11+, keyword names are supplied by a KW_NAMES instruction
            # placed immediately before the PRECALL (3.11) / CALL (3.12)
            kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
            output.insert(idx, kw_names_inst)
            return output
        # before 3.11, keyword names are a const tuple consumed by CALL_FUNCTION_KW
        return [
            self.create_load_const(kw_names),
            create_instruction("CALL_FUNCTION_KW", arg=nargs),
        ]
    def create_delete(self, value) -> Instruction:
        return create_instruction("DELETE_FAST", argval=value)
|