# exported_program.py

# mypy: allow-untyped-defs
import copy
import dataclasses
import functools
import re
import types
import warnings
from collections import namedtuple
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

from torch.fx.immutable_collections import immutable_dict, immutable_list


if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.
    import sympy

    from torch.utils._sympy.value_ranges import ValueRanges

import torch
import torch.utils._pytree as pytree
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.fx._compatibility import compatibility
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts

from .graph_signature import (  # noqa: F401
    _sig_to_specs,
    ArgumentSpec,
    ConstantArgument,
    CustomObjArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymIntArgument,
    TensorArgument,
    TokenArgument,
)


__all__ = [
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
]


PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


@dataclasses.dataclass
class ModuleCallSignature:
    inputs: List[ArgumentSpec]
    outputs: List[ArgumentSpec]
    in_spec: pytree.TreeSpec
    out_spec: pytree.TreeSpec


@dataclasses.dataclass
class ModuleCallEntry:
    fqn: str
    signature: Optional[ModuleCallSignature] = None


def _disable_prexisiting_fake_mode(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with maybe_disable_fake_tensor_mode():
            return fn(*args, **kwargs)

    return wrapper


def _fx_collection_equivalence_fn(
    spec1_type: Optional[type],
    spec1_context: pytree.Context,
    spec2_type: Optional[type],
    spec2_context: pytree.Context,
) -> bool:
    """Treat containers and their immutable variants as the same type. Otherwise
    compare as normal.
    """
    if spec1_type is None or spec2_type is None:
        return spec1_type is spec2_type and spec1_context == spec2_context

    if issubclass(spec1_type, (dict, immutable_dict)) and issubclass(
        spec2_type, (dict, immutable_dict)
    ):
        return spec1_context == spec2_context

    if issubclass(spec1_type, (list, immutable_list)) and issubclass(
        spec2_type, (list, immutable_list)
    ):
        return spec1_context == spec2_context

    return spec1_type is spec2_type and spec1_context == spec2_context
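

# Illustrative sketch (not part of the original API surface; the helper name is
# hypothetical): shows how _fx_collection_equivalence_fn lets is_equivalent
# treat tree specs built from mutable containers and their torch.fx immutable
# variants as matching, even though their spec types differ.
def _demo_fx_collection_equivalence() -> bool:
    # A plain list and an immutable_list flatten to specs with different types,
    # but the equivalence function above considers them interchangeable.
    _, spec_mutable = pytree.tree_flatten([1, 2, 3])
    _, spec_immutable = pytree.tree_flatten(immutable_list([1, 2, 3]))
    return is_equivalent(spec_mutable, spec_immutable, _fx_collection_equivalence_fn)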


def _rename_without_collisions(
    name_map: Dict[str, str],
    orig_name: str,
    name: str,
    is_placeholder: bool = False,
):
    """
    Renames nodes to avoid name collisions, with suffixing.
    name_map: map from original name to new name
    orig_name: mapping key
    name: candidate name (potentially suffixed, e.g. mul_2)
    is_placeholder: if the node is a placeholder, avoid detecting suffix
    """
    if name in name_map.values():
        # Non-placeholder nodes may already be suffixed with a count;
        # instead of adding another suffix, we try to increment it.
        match = re.match(r"(.*)_(\d+)", name)
        if match and not is_placeholder:
            name, n = match.group(1), int(match.group(2))
        else:
            n = 0
        while (dup_name := f"{name}_{n + 1}") in name_map.values():
            n += 1
        name_map[orig_name] = dup_name
    else:
        name_map[orig_name] = name
    return name_map[orig_name]
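

# Illustrative sketch (not part of the original API surface; the helper name
# and the sample node names are made up): demonstrates the suffixing behavior
# of _rename_without_collisions. A colliding non-placeholder "mul_2" is bumped
# to the next free count suffix ("mul_3"), while a colliding placeholder keeps
# its full name and receives a fresh "_1" suffix instead ("mul_2_1").
def _demo_rename_without_collisions() -> Dict[str, str]:
    name_map: Dict[str, str] = {}
    _rename_without_collisions(name_map, "mul_2", "mul_2")  # no collision: kept as-is
    _rename_without_collisions(name_map, "mul_2_dup", "mul_2")  # collision: becomes "mul_3"
    _rename_without_collisions(
        name_map, "mul_2_ph", "mul_2", is_placeholder=True
    )  # placeholder collision: becomes "mul_2_1"
    return name_map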


def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
    """
    Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs,
    and handle collisions with non-placeholders by count suffixing.
    Different HOO subgraph types have different input schemas, so we first enumerate them
    and gather the top-level named placeholder nodes.
    """
    # gather all HOO subgraphs and their top-level named placeholder nodes
    subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = []
    for node in gm.graph.nodes:
        if node.op == "call_function" and isinstance(
            node.target, torch._ops.HigherOrderOperator
        ):
            # HOO subgraphs have varying input schemas, so we enumerate them here
            if node.target._name == "cond":
                _, true_graph, false_graph, cond_args = node._args
                subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args))
                subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args))
            elif node.target._name == "wrap_with_set_grad_enabled":
                subgraph, phs = node._args[1], node._args[2:]
                subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs))
            elif node.target._name == "map_impl":
                body_graph, array, args = node._args
                subgraph_ph_tuples.append(
                    (getattr(gm, body_graph.target), array + args)
                )

    # propagate names
    for subgraph, hoo_phs in subgraph_ph_tuples:
        name_map: Dict[str, str] = {}
        for i, node in enumerate(subgraph.graph.nodes):
            if i < len(hoo_phs):  # placeholder, retain name
                name_map[node.name] = hoo_phs[i].name
                node.name = node.target = hoo_phs[i].name
            else:  # non-placeholder, check for collisions
                node.name = _rename_without_collisions(name_map, node.name, node.name)

        # recurse and recompile
        _name_hoo_subgraph_placeholders(subgraph)
        subgraph.recompile()
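

# Illustrative sketch (not part of the original API surface; the model and
# helper names are hypothetical): exports a module that uses torch.cond, then
# runs the placeholder-naming pass and collects the placeholder names of each
# branch subgraph, which mirror the parent graph's placeholder names.
def _demo_hoo_subgraph_placeholder_names():
    class _CondModel(torch.nn.Module):
        def forward(self, x):
            def true_fn(x):
                return x.sin()

            def false_fn(x):
                return x.cos()

            return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

    ep = torch.export.export(_CondModel(), (torch.randn(3),))
    gm = ep.graph_module
    # The export pipeline already applies this pass; calling it again here is a
    # no-op rename and simply illustrates the expected end state.
    _name_hoo_subgraph_placeholders(gm)
    return {
        name: [n.name for n in sub.graph.nodes if n.op == "placeholder"]
        for name, sub in gm.named_modules()
        if name and isinstance(sub, torch.fx.GraphModule)
    }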


class ExportedProgram:
    """
    Package of a program from :func:`export`. It contains
    an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing
    tensor values of all lifted parameters and buffers, and various metadata.

    An ExportedProgram cannot be called directly; use :meth:`module` to obtain a
    runnable :class:`torch.fx.GraphModule` that preserves the calling convention
    of the callable originally traced by :func:`export`.

    To perform transformations on the graph, use the ``.module()`` method to access
    an :class:`torch.fx.GraphModule`. You can then use
    `FX transformation <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
    to rewrite the graph. Afterwards, you can simply use :func:`export`
    again to construct a correct ExportedProgram.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: torch.fx.Graph,
        graph_signature: ExportGraphSignature,
        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
        range_constraints: "Dict[sympy.Symbol, Any]",
        module_call_graph: List[ModuleCallEntry],
        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
        verifier: Optional[Type[Any]] = None,  # TODO Change typing hint to Verifier.
        tensor_constants: Optional[
            Dict[str, torch.Tensor]
        ] = None,  # TODO: deprecate this
        constants: Optional[
            Dict[str, Union[torch.Tensor, torch._C.ScriptObject]]
        ] = None,
    ):
        # Remove codegen-related things from the graph. It should just be a flat graph.
        graph._codegen = torch.fx.graph.CodeGen()
        self._graph_module = _create_graph_module_for_export(root, graph)
        if isinstance(root, torch.fx.GraphModule):
            self._graph_module.meta.update(root.meta)

        self._graph_signature: ExportGraphSignature = graph_signature
        self._state_dict: Dict[str, Any] = state_dict
        self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
        assert module_call_graph is not None
        self._module_call_graph: List[ModuleCallEntry] = module_call_graph
        self._example_inputs = example_inputs

        self._constants = tensor_constants or constants or {}
        assert self._constants is not None

        from torch._export.verifier import Verifier

        if verifier is None:
            verifier = Verifier
        assert issubclass(verifier, Verifier)
        self._verifier = verifier

        # Validation should always be the last step of the constructor.
        self.verifier().check(self)

    @property
    @compatibility(is_backward_compatible=False)
    def graph_module(self):
        return self._graph_module

    @property
    @compatibility(is_backward_compatible=False)
    def graph(self):
        return self.graph_module.graph

    @property
    @compatibility(is_backward_compatible=False)
    def graph_signature(self):
        return self._graph_signature

    @property
    @compatibility(is_backward_compatible=False)
    def state_dict(self):
        return self._state_dict

    @compatibility(is_backward_compatible=False)
    def parameters(self) -> Iterator[torch.nn.Parameter]:
        """
        Returns an iterator over original module's parameters.
        """
        for _, param in self.named_parameters():
            yield param

    @compatibility(is_backward_compatible=False)
    def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Returns an iterator over original module parameters, yielding
        both the name of the parameter as well as the parameter itself.
        """
        for param_name in self.graph_signature.parameters:
            yield param_name, self.state_dict[param_name]

    @compatibility(is_backward_compatible=False)
    def buffers(self) -> Iterator[torch.Tensor]:
        """
        Returns an iterator over original module buffers.
        """
        for _, buf in self.named_buffers():
            yield buf

    @compatibility(is_backward_compatible=False)
    def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Returns an iterator over original module buffers, yielding
        both the name of the buffer as well as the buffer itself.
        """
        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
        for buffer_name in self.graph_signature.buffers:
            if buffer_name in non_persistent_buffers:
                yield buffer_name, self.constants[buffer_name]
            else:
                yield buffer_name, self.state_dict[buffer_name]

    @property
    @compatibility(is_backward_compatible=False)
    def range_constraints(self):
        return self._range_constraints

    @property
    @compatibility(is_backward_compatible=False)
    def module_call_graph(self):
        return self._module_call_graph

    @property
    @compatibility(is_backward_compatible=False)
    def example_inputs(self):
        return self._example_inputs

    @property
    @compatibility(is_backward_compatible=False)
    def call_spec(self):
        CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])

        if len(self.module_call_graph) == 0:
            return CallSpec(in_spec=None, out_spec=None)
        assert self.module_call_graph[0].fqn == ""
        return CallSpec(
            in_spec=self.module_call_graph[0].signature.in_spec,
            out_spec=self.module_call_graph[0].signature.out_spec,
        )

    @property
    @compatibility(is_backward_compatible=False)
    def verifier(self) -> Any:
        return self._verifier

    @property
    @compatibility(is_backward_compatible=False)
    def dialect(self) -> str:
        return self._verifier.dialect

    @property
    @compatibility(is_backward_compatible=False)
    def tensor_constants(self):
        return self._constants

    @property
    @compatibility(is_backward_compatible=False)
    def constants(self):
        return self._constants

    def _get_flat_args_with_check(self, args, kwargs):
        """Flatten args, kwargs using pytree, then check specs.

        Args:
            args: List[Any] original args passed to __call__
            kwargs: Dict[str, Any] original kwargs passed to __call__

        Returns:
            A tuple of (flat_args, received_spec)
            flat_args is flattened args / kwargs
            received_spec is the pytree spec produced while flattening the
            tuple (args, kwargs)
        """
        in_spec = self.call_spec.in_spec
        if in_spec is not None:
            kwargs = reorder_kwargs(kwargs, in_spec)
        flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
            (args, kwargs)
        )  # type: ignore[possibly-undefined]
        self._check_input_constraints(flat_args_with_path)
        flat_args = tuple(x[1] for x in flat_args_with_path)
        return flat_args, received_spec

    def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any:
        """Transform args, kwargs of __call__ to args for graph_module.

        self.graph_module takes parameters and buffers from the state dict as inputs.
        The invariant for ep: ExportedProgram is
            ep(args, kwargs) ==
            ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
        """
        in_spec = self.call_spec.in_spec
        flat_args, received_spec = self._get_flat_args_with_check(args, kwargs)
        if in_spec is not None and not is_equivalent(
            received_spec, in_spec, _fx_collection_equivalence_fn
        ):
            raise ValueError(
                "Trying to flatten user inputs with exported input tree spec: \n"
                f"{in_spec}\n"
                "but actually got inputs with tree spec of: \n"
                f"{received_spec}"
            )

        additional_inputs = []
        for input_ in self.graph_signature.input_specs:
            if input_.kind == InputKind.USER_INPUT:
                continue
            elif input_.kind in (
                InputKind.PARAMETER,
                InputKind.BUFFER,
            ):
                if input_.persistent is False:
                    # This is a non-persistent buffer, grab it from our
                    # constants instead of the state dict.
                    additional_inputs.append(self.constants[input_.target])
                else:
                    additional_inputs.append(self.state_dict[input_.target])
            elif input_.kind in (
                InputKind.CONSTANT_TENSOR,
                InputKind.CUSTOM_OBJ,
            ):
                additional_inputs.append(self.constants[input_.target])
        additional_inputs = tuple(additional_inputs)

        # NOTE: calling convention is first params, then buffers, then args as user supplied them.
        # See: torch/_functorch/aot_autograd.py#L1034
        return additional_inputs + flat_args

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise RuntimeError(
            "Unable to call ExportedProgram directly. "
            "You should use `exported_program.module()` instead."
        )

    def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs):
        """Process potential mutations to the input.

        Because self.graph_module is functional, mutations have to be written
        back after execution of graph_module.
        """
        import torch._export.error as error

        flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs)
        if self.call_spec.out_spec is not None:
            buffer_mutation = self.graph_signature.buffers_to_mutate
            user_input_mutation = self.graph_signature.user_inputs_to_mutate
            num_mutated = len(buffer_mutation) + len(user_input_mutation)
            mutated_values = res[:num_mutated]

            # Exclude dependency token from final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(  # noqa: B904
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                user_inputs = [
                    spec
                    for spec in self.graph_signature.input_specs
                    if spec.kind == InputKind.USER_INPUT
                ]
                for i, value in enumerate(mutated_values):
                    output_spec = self.graph_signature.output_specs[i]
                    if output_spec.kind == OutputKind.BUFFER_MUTATION:
                        assert output_spec.target is not None
                        self.state_dict[output_spec.target] = value
                    elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
                        assert output_spec.target is not None
                        index = next(
                            i
                            for i, spec in enumerate(user_inputs)
                            if spec.arg.name == output_spec.target
                        )
                        flat_args[index].copy_(value)
                    else:
                        raise AssertionError(f"Unexpected kind: {output_spec.kind}")
        return res

    def __str__(self) -> str:
        graph_module = self.graph_module.print_readable(print_output=False).replace(
            "\n", "\n    "
        )
        string = (
            "ExportedProgram:\n"
            f"    {graph_module}\n"
            f"Graph signature: {self.graph_signature}\n"
            f"Range constraints: {self.range_constraints}\n"
        )
        return string

    def module(self) -> torch.nn.Module:
        """
        Returns a self-contained GraphModule with all the parameters/buffers inlined.
        """
        from ._unlift import _unlift_exported_program_lifted_states

        module = _unlift_exported_program_lifted_states(self)

        def _train(self, mode: bool = True):
            raise NotImplementedError("Calling train() is not supported yet.")

        def _eval(self, mode: bool = True):
            raise NotImplementedError("Calling eval() is not supported yet.")

        module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
        module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
        return module

    def _num_lifted_params_buffers(self):
        return next(
            (
                i
                for i, s in enumerate(self._graph_signature.input_specs)
                if s.kind == InputKind.USER_INPUT
            ),
            len(self._graph_signature.input_specs),
        )

    @_disable_prexisiting_fake_mode
    def run_decompositions(
        self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
    ) -> "ExportedProgram":
        """
        Run a set of decompositions on the exported program and return a new
        exported program. By default we will run the Core ATen decompositions to
        get operators in the
        `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.

        For now, we do not decompose joint graphs.
        """
        from torch._decomp import core_aten_decompositions
        from torch._export.passes.lift_constants_pass import (
            ConstantAttrMap,
            lift_constants_pass,
        )
        from torch._export.passes.replace_sym_size_ops_pass import (
            _replace_sym_size_ops_pass,
        )
        from torch._functorch.aot_autograd import aot_export_module

        def _get_placeholders(gm):
            placeholders = []
            for node in gm.graph.nodes:
                if node.op != "placeholder":
                    break
                placeholders.append(node)
            return placeholders

        if decomp_table is None:
            decomp_table = core_aten_decompositions()

        old_placeholders = _get_placeholders(self.graph_module)
        fake_args = [node.meta["val"] for node in old_placeholders]

        buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()]
        for name in buffers_to_remove:
            delattr(self.graph_module, name)
        # TODO(zhxhchen17) Return the new graph_signature directly.
        from torch.export._trace import _ignore_backend_decomps

        with _ignore_backend_decomps():
            gm, graph_signature = aot_export_module(
                self.graph_module,
                fake_args,
                decompositions=decomp_table,
                trace_joint=False,
            )

        # Update the signatures with the new placeholder names in case they
        # changed when calling aot_export
        def update_arg(old_arg, new_ph):
            if isinstance(old_arg, ConstantArgument):
                return old_arg
            elif isinstance(old_arg, TensorArgument):
                return TensorArgument(name=new_ph.name)
            elif isinstance(old_arg, SymIntArgument):
                return SymIntArgument(name=new_ph.name)
            raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")

        new_placeholders = _get_placeholders(gm)
        new_outputs = list(gm.graph.nodes)[-1].args[0]

        # rename the placeholders
        assert len(new_placeholders) == len(old_placeholders)
        for old_ph, new_ph in zip(old_placeholders, new_placeholders):
            new_ph.name = new_ph.target = old_ph.name

        # handle name collisions with newly decomposed graph nodes
        name_map = {ph.name: ph.name for ph in new_placeholders}
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                continue
            node.name = _rename_without_collisions(name_map, node.name, node.name)

        # propagate names to higher order op subgraphs
        _name_hoo_subgraph_placeholders(gm)
        # To match each output target with the correct input for input mutations,
        # we need to build the old-to-new placeholder name mapping.
        old_new_placeholder_map = {
            spec.arg.name: new_placeholders[i].name
            for i, spec in enumerate(self.graph_signature.input_specs)
            if not isinstance(spec.arg, ConstantArgument)
        }

        input_specs = [
            InputSpec(
                spec.kind,
                update_arg(spec.arg, new_placeholders[i]),
                spec.target,
                spec.persistent,
            )
            for i, spec in enumerate(self.graph_signature.input_specs)
        ]
        output_specs = [
            OutputSpec(
                spec.kind,
                update_arg(spec.arg, new_outputs[i]),
                old_new_placeholder_map.get(spec.target, spec.target),
            )
            for i, spec in enumerate(self.graph_signature.output_specs)
        ]

        assert len(new_placeholders) == len(old_placeholders)

        new_graph_signature = ExportGraphSignature(
            input_specs=input_specs, output_specs=output_specs
        )
        # NOTE: aot_export adds symint metadata for placeholders with int
        # values; since these become specialized, we replace such metadata with
        # the original values.
        # Also, set the param/buffer metadata back to the placeholders.
        for old_node, new_node in zip(old_placeholders, new_placeholders):
            if not isinstance(old_node.meta["val"], torch.Tensor):
                new_node.meta["val"] = old_node.meta["val"]

            if (
                new_node.target in new_graph_signature.inputs_to_parameters
                or new_node.target in new_graph_signature.inputs_to_buffers
            ):
                for k, v in old_node.meta.items():
                    new_node.meta[k] = v

        # TODO unfortunately preserving graph-level metadata is not
        # working well with aot_export. So we manually copy it.
        # (The node-level meta is addressed above.)
        gm.meta.update(self.graph_module.meta)

        new_range_constraints = _get_updated_range_constraints(
            gm,
            self.range_constraints,
            _is_executorch=False,
        )

        constants = lift_constants_pass(gm, new_graph_signature, ConstantAttrMap())
        for k, v in constants.items():
            assert k not in self.constants
            self.constants[k] = v

        _replace_sym_size_ops_pass(gm)
        from torch._dynamo import config as _dynamo_config
        from torch._export.passes._node_metadata_hook import (
            _node_metadata_hook,
            _set_node_metadata_hook,
        )

        if not _dynamo_config.do_not_emit_runtime_asserts:
            stack_trace = (
                'File "torch/fx/passes/runtime_assert.py", line 24, '
                "in insert_deferred_runtime_asserts"
            )
            shape_env = _get_shape_env(gm)
            if shape_env is not None:
                with _set_node_metadata_hook(
                    gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
                ):
                    insert_deferred_runtime_asserts(
                        gm,
                        shape_env,
                        f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
                        export=True,
                    )

        exported_program = ExportedProgram(
            root=gm,
            graph=gm.graph,
            graph_signature=new_graph_signature,
            state_dict=self.state_dict,
            range_constraints=new_range_constraints,
            module_call_graph=copy.deepcopy(self.module_call_graph),
            example_inputs=self.example_inputs,
            verifier=self.verifier,
            constants=self.constants,
        )
        return exported_program

    def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram":
        pm = PassManager(list(passes))
        # Since we abstractly run the passes, we need to disable backend decomp here
        # again.
        from torch.export._trace import _ignore_backend_decomps

        with _ignore_backend_decomps():
            res = pm(self.graph_module)
        transformed_gm = res.graph_module if res is not None else self.graph_module
        assert transformed_gm is not None

        if transformed_gm is self.graph_module and not res.modified:
            return self

        # TODO(zhxchen17) Remove this.
        def _get_updated_graph_signature(
            old_signature: ExportGraphSignature,
            new_gm: torch.fx.GraphModule,
        ) -> ExportGraphSignature:
            """
            Update the graph signature's user_input/user_outputs.
            """
            new_input_specs = []
            for i, node in enumerate(new_gm.graph.nodes):
                if node.op != "placeholder":
                    break

                assert i < len(
                    old_signature.input_specs
                ), "Number of inputs changed after transformation"
                old_input_spec = old_signature.input_specs[i]
                arg = (
                    old_input_spec.arg
                    if isinstance(
                        old_input_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_input_spec.arg)(node.name)
                )
                new_input_specs.append(
                    InputSpec(
                        old_input_spec.kind,
                        arg,
                        old_input_spec.target,
                        old_input_spec.persistent,
                    )
                )

            output_node = list(new_gm.graph.nodes)[-1]
            assert output_node.op == "output"

            new_output_specs = []
            for i, node in enumerate(output_node.args[0]):
                assert i < len(
                    old_signature.output_specs
                ), "Number of outputs changed after transformation"
                old_output_spec = old_signature.output_specs[i]
                arg = (
                    old_output_spec.arg
                    if isinstance(
                        old_output_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_output_spec.arg)(node.name)
                )
                new_output_specs.append(
                    OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
                )

            new_signature = ExportGraphSignature(
                input_specs=new_input_specs, output_specs=new_output_specs
            )
            return new_signature

        transformed_ep = ExportedProgram(
            root=transformed_gm,
            graph=transformed_gm.graph,
            graph_signature=_get_updated_graph_signature(
                self.graph_signature, transformed_gm
            ),
            state_dict=self.state_dict,
            range_constraints=_get_updated_range_constraints(
                transformed_gm,
                self.range_constraints,
                _is_executorch=False,
            ),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            verifier=self.verifier,
            constants=self.constants,
        )
        transformed_ep.graph_module.meta.update(self.graph_module.meta)
        transformed_ep.graph_module.meta.update(res.graph_module.meta)
        return transformed_ep

    def _check_input_constraints(self, flat_args_with_path):
        from torch._export.utils import _check_input_constraints_for_graph

        placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
        input_placeholders = [
            p
            for p, s in zip(placeholders, self.graph_signature.input_specs)
            if s.kind == InputKind.USER_INPUT
        ]
        _check_input_constraints_for_graph(
            input_placeholders, flat_args_with_path, self.range_constraints
        )

    def _validate(self):
        self.verifier().check(self)

    # TODO(zhxchen17) Formalize this.
    def _update(
        self, graph_module, graph_signature, state_dict=None
    ) -> "ExportedProgram":
        return ExportedProgram(
            root=graph_module,
            graph=graph_module.graph,
            graph_signature=graph_signature,
            state_dict=state_dict or self.state_dict,
            range_constraints=copy.deepcopy(self.range_constraints),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            verifier=self.verifier,
            tensor_constants=self.tensor_constants,
        )
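

# Illustrative sketch (not part of the original API surface): a minimal
# end-to-end use of ExportedProgram built only from public torch.export APIs.
# It exports a tiny module, inspects the lifted parameters and the graph
# signature, obtains a runnable module via .module(), and lowers the program
# to the Core ATen operator set with run_decompositions(). The model and
# function names below are hypothetical.
def _demo_exported_program_usage():
    class _TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, x):
            return torch.relu(self.linear(x))

    example_args = (torch.randn(3, 4),)
    ep = torch.export.export(_TinyModel(), example_args)

    # Parameters are lifted into the graph signature / state_dict rather than
    # kept as attributes on the exported graph module.
    param_names = [name for name, _ in ep.named_parameters()]
    user_inputs = [
        spec.arg.name
        for spec in ep.graph_signature.input_specs
        if spec.kind == InputKind.USER_INPUT
    ]

    # ExportedProgram itself is not callable; .module() returns an unlifted
    # torch.fx.GraphModule that preserves the original calling convention.
    out = ep.module()(*example_args)

    # Decompose to the Core ATen operator set; this returns a new ExportedProgram.
    core_ep = ep.run_decompositions()
    return param_names, user_inputs, out, core_ep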


def _get_shape_env(gm):
    vals = [
        node.meta["val"]
        for node in gm.graph.nodes
        if node.meta.get("val", None) is not None
    ]
    from torch._guards import detect_fake_mode

    fake_mode = detect_fake_mode(vals)
    if fake_mode is not None:
        return fake_mode.shape_env
    for v in vals:
        if isinstance(v, torch.SymInt):
            return v.node.shape_env
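

# Illustrative sketch (not part of the original API surface; the model and
# function names are hypothetical): shows how a ShapeEnv can be recovered from
# a graph exported with a dynamic dimension, which is what _get_shape_env does
# for run_decompositions and for rebuilding range constraints.
def _demo_get_shape_env():
    class _TinyModel(torch.nn.Module):
        def forward(self, x):
            return x * 2

    batch = torch.export.Dim("batch", min=2, max=64)
    ep = torch.export.export(
        _TinyModel(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}}
    )

    # The placeholders' FakeTensor metadata carries the shared ShapeEnv, and
    # the exported program records the allowed range of the dynamic symbol.
    shape_env = _get_shape_env(ep.graph_module)
    return shape_env, ep.range_constraints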


def _get_updated_range_constraints(
    gm: torch.fx.GraphModule,
    old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None,
    _is_executorch: bool = True,
) -> "Dict[sympy.Symbol, Any]":
    # FIXME(tmanlaibaatar) Remove this whole branch once
    # https://github.com/pytorch/pytorch/pull/123764 lands.
    if _is_executorch:
        assert old_range_constraints is None
        shape_env = _get_shape_env(gm)
        if shape_env is None:
            return {}
        range_constraints = {
            k: v
            for k, v in shape_env.var_to_range.items()
            if k not in shape_env.replacements
        }
        # Only when we have an unbacked symint, and it's used as constructor inputs,
        # runtime_var_to_range will make a difference compared to var_to_range.
        # e.g. [2, oo) -> [0, oo)
        for k, v in shape_env.var_to_range.items():
            if k not in shape_env.replacements:
                range_constraints[k] = v
        return range_constraints

    assert old_range_constraints is not None

    shape_env = _get_shape_env(gm)
    if shape_env is None:
        return {}

    range_constraints = copy.copy(old_range_constraints)
    range_constraints = {
        k: v for k, v in range_constraints.items() if k not in shape_env.replacements
    }

    # Only when we have an unbacked symint, and it's used as constructor inputs,
    # runtime_var_to_range will make a difference compared to var_to_range.
    # e.g. [2, oo) -> [0, oo)
    for k, v in shape_env.var_to_range.items():
        if k not in shape_env.replacements and k not in range_constraints:
            range_constraints[k] = v

    return range_constraints


def _create_graph_module_for_export(root, graph):
    try:
        gm = torch.fx.GraphModule(root, graph)
    except SyntaxError:
        # If custom objects stored in memory are being used in the graph,
        # the generated python code will result in a syntax error on the custom
        # object, since it is unable to parse the in-memory object. However
        # we can still run the graph eagerly through torch.fx.Interpreter,
        # so we will bypass this error.
        warnings.warn(
            "Unable to execute the generated python source code from "
            "the graph. The graph module will no longer be directly callable, "
            "but you can still run the ExportedProgram, and if needed, you can "
            "run the graph module eagerly using torch.fx.Interpreter."
        )
        gm = torch.fx.GraphModule(root, torch.fx.Graph())
        gm._graph = graph

    return gm
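

# Illustrative sketch (not part of the original API surface; names are made
# up): the fallback in _create_graph_module_for_export gives up on generated
# Python source but keeps the graph itself, which torch.fx.Interpreter can
# still execute eagerly. The snippet shows an fx graph being run node-by-node
# through the Interpreter, without relying on the GraphModule's forward().
def _demo_interpreter_fallback():
    class _TinyModel(torch.nn.Module):
        def forward(self, x):
            return x + 1

    traced = torch.fx.symbolic_trace(_TinyModel())
    # Run the graph eagerly, as one would for a graph module whose generated
    # Python source could not be compiled.
    return torch.fx.Interpreter(traced).run(torch.randn(2))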