codegen.py

# mypy: allow-untyped-defs
import collections
import dataclasses
import re
import sys
import types
from typing import Counter, Dict, List, Optional

import torch.nn

from . import utils
from .bytecode_transformation import (
    create_call_function,
    create_call_method,
    create_dup_top,
    create_instruction,
    create_load_attr,
    create_load_global,
    create_load_method,
    create_rot_n,
    Instruction,
)
from .exc import unimplemented
from .source import AttrSource, Source
from .utils import is_safe_constant, rot_n_helper
from .variables.base import VariableTracker
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable


@dataclasses.dataclass
class GraphOutputEntry:
    index: int
    variable: VariableTracker


class PyCodegen:
    """
    Helper class used for constructing Python bytecode
    """

    def __init__(
        self,
        tx=None,
        root: Optional[torch.nn.Module] = None,
        graph_output_var: Optional[str] = None,
        tempvars=None,
    ):
        self.root = root
        self.top_of_stack: Optional[VariableTracker] = None
        self.uses: Counter[VariableTracker] = collections.Counter()
        self.graph_outputs: Dict[int, GraphOutputEntry] = {}
        self._output: List[Instruction] = []
        self.tempvars = tempvars or {}
        self.tx = tx
        self.graph_output_var = graph_output_var
        self.code_options = self.tx.output.code_options
        self.cell_and_freevars = self.tx.cell_and_freevars
        self.new_var = self.tx.output.new_var
        self.mutable_side_effects_from_source = False
        self.value_from_source: bool = True

    def restore_stack(self, stack_values, *, value_from_source=True):
        # Temporarily codegen mutated values as references to their source
        # while regenerating the stack contents.
        prior = self.mutable_side_effects_from_source
        self.mutable_side_effects_from_source = True
        prev = self.value_from_source
        self.value_from_source &= value_from_source
        try:
            self.foreach(stack_values)
        finally:
            self.mutable_side_effects_from_source = prior
            self.value_from_source = prev

    def graph_output_vars(self):
        return [x.variable for x in self.graph_outputs.values()]

    def call_reconstruct(self, value):
        res = value.reconstruct(self)
        assert res is None, f"reconstruct!=None {value}"

    def __call__(self, value, allow_cache=True):
        """Generate code such that top-of-stack (TOS) is set to value"""
        if isinstance(value, Source):
            self.call_reconstruct(value)
            self.clear_tos()
            return

        assert isinstance(value, VariableTracker)
        output = self._output
        graph_outputs = self.graph_outputs

        if self.top_of_stack is value and allow_cache:
            output.append(create_dup_top())
            return

        if self.mutable_side_effects_from_source:
            # this is needed to get aliasing relationships right
            # value.mutable_local.source will get mutated to hold `value`
            # mutable_side_effects_from_source=False is used to codegen the mutation
            # mutable_side_effects_from_source=True is used to codegen a reference
            from .side_effects import MutableSideEffects

            if isinstance(value.mutable_local, MutableSideEffects):
                self(value.mutable_local.source)
                return

        if allow_cache:
            if value.mutable_local and value.mutable_local in self.tempvars:
                output.append(self.create_load(self.tempvars[value.mutable_local]))
                self.top_of_stack = value
                return
            if self.tempvars.get(value) is not None:
                output.append(self.create_load(self.tempvars[value]))
                self.top_of_stack = value
                return

        if value.source is not None and allow_cache and self.value_from_source:
            self.call_reconstruct(value.source)
        elif value.is_python_constant() and is_safe_constant(
            value.as_python_constant()
        ):
            output.append(self.create_load_const(value.as_python_constant()))
        elif isinstance(value, TensorWithTFOverrideVariable):
            graph_outputs_key = self.add_graph_output(value)

            self.load_import_from(utils.__name__, "to_subclass")
            self.load_graph_output(graph_outputs[graph_outputs_key].index)
            output.append(
                self.create_load_global(
                    value.global_mangled_class_name(self.tx), False, add=True
                )
            )
            output.extend(create_call_function(2, True))
        elif (
            isinstance(value, SymNodeVariable)
            and value.python_type() == float
            and not self.tx.export
        ):
            # This is a little unusual; force the output convention to be a
            # Tensor here. Don't do this for export because this is
            # apparently load bearing for export tests (but I am a bit
            # doubtful it actually works in the real world)
            # NB: It works to add_graph_output on a computed expression
            # as_tensor here, because we memoize as_tensor calls on
            # SymNodeVariable!
            graph_outputs_key = self.add_graph_output(value.as_tensor(self.tx))

            self.load_graph_output(graph_outputs[graph_outputs_key].index)
            output.extend(
                [self.create_load_attr("item")] + create_call_function(0, True)
            )
        elif isinstance(
            value,
            (
                TensorVariable,
                SymNodeVariable,
                UnspecializedPythonVariable,
                NumpyNdarrayVariable,
            ),
        ):
            graph_outputs_key = self.add_graph_output(value)

            if isinstance(value, NumpyNdarrayVariable):
                self.load_import_from(utils.__name__, "to_numpy_helper")

            self.load_graph_output(graph_outputs[graph_outputs_key].index)

            if isinstance(value, NumpyNdarrayVariable):
                output.extend(create_call_function(1, True))
            elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
                output.extend(
                    [self.create_load_attr("item")] + create_call_function(0, True)
                )
        elif isinstance(value, NNModuleVariable):
            parts = value.module_key.split(".")
            if parts[0] in self.code_options["co_varnames"]:
                output.append(self.create_load(parts[0]))
                parts = parts[1:]
            else:
                assert self.root is not None
                output.append(self.create_load_output(self.root))
            for part in parts:
                output.append(self.create_load_attr(part))
        else:
            self.uses[value] += 1
            try:
                self.call_reconstruct(value)
            except NotImplementedError:
                unimplemented(f"reconstruct: {value}")
            if allow_cache and value in self.tempvars:
                self._output.append(create_dup_top())
                self.add_cache(value)

        self.top_of_stack = value
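
    # Note: when allow_cache is true and `value` is registered in tempvars,
    # the fallthrough branch above duplicates TOS and stores one copy via
    # add_cache(), so subsequent requests for the same value compile to a
    # single LOAD_FAST/LOAD_DEREF instead of re-running reconstruction.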

    def add_graph_output(self, value):
        graph_outputs_key = id(value.as_proxy())
        if graph_outputs_key not in self.graph_outputs:
            self.graph_outputs[graph_outputs_key] = GraphOutputEntry(
                len(self.graph_outputs), value
            )
        return graph_outputs_key
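
    # Graph outputs are keyed by the id() of the underlying proxy, so
    # requesting the same proxy twice reuses one output slot rather than
    # returning the value from the graph twice.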

    def load_graph_output(self, index):
        output = self._output
        output.append(self.create_load(self.graph_output_var))
        output.append(self._create_load_const(index))
        output.append(create_instruction("BINARY_SUBSCR"))

    def add_cache(self, value):
        var = self.new_var()
        self.tempvars[value] = var
        if value.mutable_local:
            self.tempvars[value.mutable_local] = var
        self._output.append(self.create_store(var))

    def foreach(self, items):
        for i in items:
            self(i)

    def setup_globally_cached(self, name, value, push_null):
        """Store value in a new global"""
        name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        f_globals = self.tx.f_globals
        if name in f_globals:
            assert id(f_globals[name]) == id(value)
        else:
            f_globals[name] = value
        return [self.create_load_global(name, push_null, add=True)]

    def clear_tos(self):
        self.top_of_stack = None

    def append_output(self, inst):
        assert isinstance(inst, Instruction)
        self._output.append(inst)
        self.clear_tos()

    def extend_output(self, insts):
        assert all(isinstance(x, Instruction) for x in insts)
        self._output.extend(insts)
        self.clear_tos()

    def get_instructions(self) -> List[Instruction]:
        return self._output

    def create_load(self, name) -> Instruction:
        if name in self.cell_and_freevars():
            return create_instruction("LOAD_DEREF", argval=name)
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction("LOAD_FAST", argval=name)

    def create_load_closure(self, name) -> Instruction:
        assert name in self.cell_and_freevars()
        return create_instruction("LOAD_CLOSURE", argval=name)

    def create_store(self, name) -> Instruction:
        if name in self.cell_and_freevars():
            return create_instruction("STORE_DEREF", argval=name)
        assert name in self.code_options["co_varnames"]
        return create_instruction("STORE_FAST", argval=name)

    def create_load_global(self, name, push_null, add=False) -> Instruction:
        if add:
            self.tx.output.update_co_names(name)
        assert name in self.code_options["co_names"], f"{name} not in co_names"
        return create_load_global(name, push_null)

    def create_load_const(self, value) -> Instruction:
        assert is_safe_constant(value), f"unsafe constant {value}"
        return self._create_load_const(value)

    def _create_load_const(self, value) -> Instruction:
        return create_instruction("LOAD_CONST", argval=value)

    create_load_output = _create_load_const

    def create_load_method(self, name):
        self.tx.output.update_co_names(name)
        return create_load_method(name)

    def load_method(self, name):
        self.append_output(self.create_load_method(name))

    def call_method(self, nargs):
        self.extend_output(create_call_method(nargs))

    def create_load_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_load_attr(name)

    def load_attr(self, name):
        self.append_output(self.create_load_attr(name))

    def create_load_attrs(self, names):
        return [self.create_load_attr(name) for name in names.split(".")]
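
    # e.g. create_load_attrs("weight.data") yields two LOAD_ATTR
    # instructions, one per dotted component.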

    def create_store_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_instruction("STORE_ATTR", argval=name)

    def store_attr(self, name):
        self.append_output(self.create_store_attr(name))

    def load_function_name(self, fn_name, push_null, num_on_stack=0):
        """Load the global fn_name, placing it num_on_stack entries below the top of the stack"""
        output = []
        if push_null and sys.version_info >= (3, 11):
            output.extend(
                [create_instruction("PUSH_NULL"), *self.rot_n(num_on_stack + 1)]
            )
        output.extend(
            [
                self.create_load_global(fn_name, False, add=True),
                *self.rot_n(num_on_stack + 1),
            ]
        )
        return output

    def rot_n(self, n):
        try:
            return create_rot_n(n)
        except AttributeError:
            # desired rotate bytecode doesn't exist, generate equivalent bytecode
            return [
                create_instruction("BUILD_TUPLE", arg=n),
                self._create_load_const(rot_n_helper(n)),
                *create_rot_n(2),
                create_instruction("CALL_FUNCTION_EX", arg=0),
                create_instruction("UNPACK_SEQUENCE", arg=n),
            ]
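
    # The fallback collects the top n stack values into a tuple, calls
    # rot_n_helper(n) on them (a helper that reorders its arguments), and
    # unpacks the result back onto the stack -- equivalent to the missing
    # rotate instruction, just slower.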

    def pop_null(self):
        # POP_TOP doesn't work for null, so we pop nulls by pushing in a
        # nop function, calling it (which consumes the null), and popping the result.
        assert sys.version_info >= (3, 11)
        return [
            self._create_load_const(lambda: None),
            *create_call_function(0, False),
            create_instruction("POP_TOP"),
        ]

    def pop_top(self):
        self.append_output(create_instruction("POP_TOP"))

    def call_function(self, nargs: int, push_null: bool):
        self.extend_output(create_call_function(nargs, push_null=push_null))

    def dup_top(self):
        self.append_output(create_dup_top())

    def store(self, varname):
        self.append_output(self.create_store(varname))

    def make_function_with_closure(
        self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0
    ):
        freevars = code.co_freevars
        assert freevars
        output = self._output

        if sys.version_info >= (3, 11) and push_null:
            output.append(create_instruction("PUSH_NULL"))
            output.extend(self.rot_n(num_on_stack + 1))

        for var in freevars:
            assert var in self.cell_and_freevars()
            output.append(create_instruction("LOAD_CLOSURE", argval=var))
        output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
        output.append(self.create_load_const(code))
        if sys.version_info < (3, 11):
            output.append(self.create_load_const(fn_name))
        output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
        output.extend(self.rot_n(num_on_stack + 1))
        self.clear_tos()
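
    # MAKE_FUNCTION with arg 0x08 is the CPython flag indicating that a
    # tuple of closure cells sits on the stack beneath the code object.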

    def create_load_python_module(self, mod, push_null) -> Instruction:
        """
        Generate a LOAD_GLOBAL instruction to fetch a given python module.
        """
        output = self.tx.output
        global_scope = output.global_scope
        name = re.sub(r"^.*[.]", "", mod.__name__)
        if global_scope.get(name, None) is mod:
            return self.create_load_global(name, push_null, add=True)
        prefix = f"___module_{name}"
        global_name = self.tx.output.install_global_by_id(prefix, mod)
        return self.create_load_global(global_name, push_null, add=True)

    def make_call_generated_code(self, fn_name: str) -> None:
        """Call the generated code function stored in fn_name"""
        self.extend_output(self.load_function_name(fn_name, True))

        graphargs = self.tx.output.graphargs
        for arg in graphargs:
            if arg.pass_arg_as_tensor:
                self.extend_output(
                    [
                        self.create_load_python_module(torch, True),
                        self.create_load_attr("as_tensor"),
                    ]
                )
                self.call_reconstruct(arg)
                self.extend_output(create_call_function(1, False))
            else:
                self.call_reconstruct(arg)

        self.extend_output(create_call_function(len(graphargs), False))
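
    # The net effect is bytecode equivalent to
    #   fn_name(arg0, arg1, ..., argN)
    # where each graph arg is reconstructed from its source, and args
    # marked pass_arg_as_tensor are first wrapped in torch.as_tensor(...).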

    def load_import_from(self, module_name, object_name) -> None:
        self(AttrSource(self.tx.import_source(module_name), object_name))

    def create_call_function_kw(self, nargs, kw_names, push_null) -> List[Instruction]:
        if sys.version_info >= (3, 11):
            output = create_call_function(nargs, push_null)
            if sys.version_info >= (3, 12):
                idx = -1
                expected_inst = "CALL"
            else:
                idx = -2
                expected_inst = "PRECALL"
            assert output[idx].opname == expected_inst
            kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
            output.insert(idx, kw_names_inst)
            return output
        return [
            self.create_load_const(kw_names),
            create_instruction("CALL_FUNCTION_KW", arg=nargs),
        ]
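
    # On 3.11 the KW_NAMES instruction must immediately precede PRECALL;
    # PRECALL was removed in 3.12, so there KW_NAMES is inserted directly
    # before CALL. Older versions use CALL_FUNCTION_KW with the names
    # tuple pushed as a constant.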

    def create_delete(self, value) -> Instruction:
        return create_instruction("DELETE_FAST", argval=value)