unflatten.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126
  1. # mypy: allow-untyped-defs
  2. import abc
  3. import copy
  4. import operator
  5. from collections import defaultdict
  6. from copy import deepcopy
  7. from enum import Enum
  8. from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
  9. import torch
  10. import torch.fx._pytree as fx_pytree
  11. import torch.utils._pytree as pytree
  12. from torch._library.fake_class_registry import FakeScriptObject
  13. from torch.export._tree_utils import reorder_kwargs
  14. from torch.export.exported_program import (
  15. ConstantArgument,
  16. ExportedProgram,
  17. InputKind,
  18. ModuleCallSignature,
  19. SymIntArgument,
  20. TensorArgument,
  21. )
  22. from torch.fx._symbolic_trace import is_fx_tracing
  23. from torch.utils._pytree import GetAttrKey, SequenceKey
  24. __all__ = ["InterpreterModule", "UnflattenedModule", "unflatten", "FlatArgsAdapter"]
class _AttrKind(Enum):
    """How `_assign_attr` should register the object on the target module."""

    PARAMETER = "parameter"  # registered via register_parameter
    BUFFER = "buffer"  # registered via register_buffer (persistence configurable)
    CONSTANT = "constant"  # plain setattr (tensor or ScriptObject)
  29. # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
  30. # This installs empty Modules where none exist yet if they are subpaths of target
  31. def _assign_attr(
  32. from_obj: Union[torch.Tensor, torch.ScriptObject],
  33. to_module: torch.nn.Module,
  34. target: str,
  35. attr_kind: _AttrKind,
  36. persistent: bool = True,
  37. ):
  38. *prefix, field = target.split(".")
  39. for item in prefix:
  40. t = getattr(to_module, item, None)
  41. if t is None:
  42. t = torch.nn.Module()
  43. setattr(to_module, item, t)
  44. to_module = t
  45. if attr_kind == _AttrKind.PARAMETER:
  46. assert isinstance(from_obj, torch.nn.Parameter)
  47. to_module.register_parameter(field, from_obj)
  48. elif attr_kind == _AttrKind.BUFFER:
  49. assert isinstance(from_obj, torch.Tensor)
  50. to_module.register_buffer(field, from_obj, persistent=persistent)
  51. elif attr_kind == _AttrKind.CONSTANT:
  52. assert not isinstance(
  53. from_obj, FakeScriptObject
  54. ), "FakeScriptObject should only exist during tracing."
  55. assert isinstance(
  56. from_obj,
  57. (
  58. torch.Tensor,
  59. torch.ScriptObject,
  60. ),
  61. )
  62. setattr(to_module, field, from_obj)
class InterpreterModule(torch.nn.Module):
    """A module that uses torch.fx.Interpreter to execute instead of the usual
    codegen that GraphModule uses. This provides better stack trace information
    and makes it easier to debug execution.
    """

    def __init__(
        self,
        graph: torch.fx.Graph,
    ):
        super().__init__()
        self.graph = graph
        self.graph.owning_module = self

    def forward(self, *args, **kwargs):
        # `graph_module` and `arg_names` are set by finalize(); calling
        # forward() before finalize() is a usage error.
        assert self.graph_module is not None, "Didn't finalize this InterpreterModule"
        if torch.compiler.is_dynamo_compiling():
            # Dynamo cannot trace through torch.fx.Interpreter, so fall back to
            # GraphModule codegen in this instance.
            return self.graph_module(*args, **kwargs)
        else:
            if kwargs:
                # Handle **kwargs. FX only natively supports positional
                # arguments (through placeholders). So in order to pass in
                # kwargs, we must correspond the names of the placeholders with
                # the keys in the kwarg dict.
                arg_list = list(args)
                # Placeholders not covered by positional args must be kwargs.
                kwarg_names = self.arg_names[len(arg_list) :]
                for kwarg_name in kwarg_names:
                    if kwarg_name in kwargs:
                        arg_list.append(kwargs[kwarg_name])

                # Assert that the kwargs passed in exactly match the positional
                # arguments specified by the GraphModule. This should be
                # guaranteed by the unflattening process.
                assert len(kwarg_names) == len(kwargs)
                assert len(arg_list) == len(self.arg_names)
                args = tuple(arg_list)
            return torch.fx.Interpreter(self, graph=self.graph).run(
                *args, enable_io_processing=False
            )

    def finalize(self):
        """Build the backing GraphModule and cache placeholder names.

        Must be called after the graph is fully constructed (and after
        _sink_params has run) and before the first forward() call.
        """
        # We need to "finalize" because GraphModule populates its own state_dict
        # based on the get_attrs observed in the graph. So we need to fully
        # construct the graph and call _sink_params before generating this
        # GraphModule.

        # need to set `graph_module` directly on the dict to avoid it getting
        # registered as a submodule.
        self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
        self.graph.lint()

        # Cache arg names for kwarg handling (see forward())
        self.arg_names = []
        for node in self.graph.nodes:
            if node.op == "placeholder":
                self.arg_names.append(node.target)
class FlatArgsAdapter(abc.ABC):
    """
    Adapts input arguments with ``input_spec`` to align ``target_spec``.
    """

    @abc.abstractmethod
    def adapt(
        self,
        target_spec: pytree.TreeSpec,
        input_spec: pytree.TreeSpec,
        input_args: List[Any],
    ) -> List[Any]:
        """Adapt ``input_args`` (flattened per ``input_spec``) into the flat
        argument list expected by ``target_spec``.

        NOTE: This adapter may mutate given ``input_args``.
        """
        ...
  128. class UnflattenedModule(torch.nn.Module):
  129. def __init__(
  130. self,
  131. export_module: ExportedProgram,
  132. flat_args_adapter: Optional[FlatArgsAdapter] = None,
  133. ):
  134. super().__init__()
  135. if export_module.graph_signature.backward_signature is not None:
  136. raise ValueError("Unflattening on JointExportModule NYI")
  137. fqn_list = [entry.fqn for entry in export_module.module_call_graph]
  138. assert fqn_list[0] == ""
  139. export_graph = deepcopy(export_module.graph)
  140. self.graph_signature = deepcopy(export_module.graph_signature)
  141. self.graph = torch.fx.Graph()
  142. self.module_call_graph = deepcopy(export_module.module_call_graph)
  143. self.flat_args_adapter = flat_args_adapter
  144. # Flag to indicate whether args have been adapted.
  145. self.adapted = False
  146. _inplace_buffer_mutations(export_graph, self.graph_signature)
  147. _outline_submodules(export_graph, self)
  148. self.range_constraints = export_module.range_constraints
  149. self.equality_constraints: List = []
  150. # aliasing/unused param or buffer issues:
  151. # in strict-mode export, dynamo export will deduplicate aliased tensors,
  152. # and ignore unused tensors. For aliasing, this causes issues when some aliases
  153. # are unused, and we're unable to match the placeholder node to the correct FQN.
  154. # This leads to the graph signature potentially having the wrong target FQN,
  155. # and downstream issues where parameters are assigned to the wrong target attribute,
  156. # mismatching the relevant placeholder node in the unflattened module.
  157. # To resolve this we restore (_assign_attr) all aliased/unused tensors in
  158. # the state_dict as module attributes, but only keep the used tensors in the
  159. # graph's forward pass (_sink_params).
  160. state_dict = export_module.state_dict
  161. assigned_params: Set[str] = set() # tracking unused params
  162. id_to_param: Dict[int, torch.nn.Parameter] = {} # handling weight-sharing
  163. for name in self.graph_signature.parameters: # this loop adds used params
  164. param = state_dict[name]
  165. if id(param) not in id_to_param:
  166. id_to_param[id(param)] = torch.nn.Parameter(param.clone())
  167. _assign_attr(
  168. id_to_param[id(param)],
  169. self,
  170. name,
  171. attr_kind=_AttrKind.PARAMETER,
  172. )
  173. assigned_params.add(name)
  174. non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
  175. assigned_buffers: Set[str] = set() # tracking unused buffers
  176. id_to_buffer: Dict[
  177. int, Tuple[torch.nn.Parameter, bool]
  178. ] = {} # handle weight-sharing
  179. for name in self.graph_signature.buffers: # this loop adds used buffers
  180. if name in non_persistent_buffers:
  181. persistent = False
  182. buffer = export_module.constants[name]
  183. else:
  184. persistent = True
  185. buffer = state_dict[name]
  186. if id(buffer) not in id_to_buffer:
  187. id_to_buffer[id(buffer)] = (buffer.clone(), persistent)
  188. _assign_attr(
  189. id_to_buffer[id(buffer)][0],
  190. self,
  191. name,
  192. attr_kind=_AttrKind.BUFFER,
  193. persistent=persistent,
  194. )
  195. assigned_buffers.add(name)
  196. # restore aliased/unused params and buffers
  197. # these appear in state dict but not graph signature
  198. for name, tensor in state_dict.items():
  199. if name in assigned_params or name in assigned_buffers: # already assigned
  200. continue
  201. is_buffer = False
  202. if id(tensor) in id_to_buffer or not isinstance(
  203. tensor, torch.nn.Parameter
  204. ): # aliased buffer
  205. is_buffer = True
  206. if is_buffer:
  207. if (
  208. id(tensor) not in id_to_buffer
  209. ): # this is completely unused (not weight-sharing)
  210. id_to_buffer[id(tensor)] = (
  211. tensor,
  212. True,
  213. ) # assign to respect original model
  214. _assign_attr(
  215. id_to_buffer[id(tensor)][0],
  216. self,
  217. name,
  218. attr_kind=_AttrKind.BUFFER,
  219. persistent=True,
  220. )
  221. else:
  222. if id(tensor) not in id_to_param: # this is unused
  223. id_to_param[id(tensor)] = tensor
  224. _assign_attr(
  225. id_to_param[id(tensor)],
  226. self,
  227. name,
  228. attr_kind=_AttrKind.PARAMETER,
  229. )
  230. # use id map so we don't double-clone aliased constants
  231. id_to_const: Dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {}
  232. for fqn, constant in export_module.constants.items():
  233. if id(constant) not in id_to_const:
  234. if isinstance(constant, torch.Tensor):
  235. constant = constant.clone()
  236. id_to_const[id(constant)] = constant
  237. _constant = id_to_const[id(constant)]
  238. _assign_attr(
  239. _constant,
  240. self,
  241. fqn,
  242. attr_kind=_AttrKind.CONSTANT,
  243. )
  244. # This is to handle parameters/buffers that point to the same tensor
  245. # object id -> list of (node_name, target_name)
  246. consts_map: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
  247. consts_targets: Set[str] = set()
  248. def add_to_consts_map(obj_id, node_name, target_name):
  249. name_list = consts_map[obj_id]
  250. name_list.append((node_name, target_name))
  251. added_params_buffers: Set[str] = set() # track aliased/unused params, buffers
  252. for s in self.graph_signature.input_specs:
  253. if s.kind == InputKind.PARAMETER or (
  254. s.kind == InputKind.BUFFER and s.persistent
  255. ):
  256. assert hasattr(s.arg, "name")
  257. assert isinstance(s.target, str)
  258. add_to_consts_map(
  259. id(export_module.state_dict[s.target]), s.arg.name, s.target
  260. )
  261. consts_targets.add(s.target)
  262. added_params_buffers.add(s.target)
  263. elif (
  264. (s.kind == InputKind.BUFFER and not s.persistent)
  265. or s.kind == InputKind.CONSTANT_TENSOR
  266. or s.kind == InputKind.CUSTOM_OBJ
  267. ):
  268. assert hasattr(s.arg, "name")
  269. assert isinstance(s.target, str)
  270. add_to_consts_map(
  271. id(export_module.constants[s.target]), s.arg.name, s.target
  272. )
  273. consts_targets.add(s.target)
  274. # add constants that are aliased and don't appear in graph signature
  275. for const_name, const in export_module.constants.items():
  276. if const_name not in consts_targets:
  277. assert (
  278. id(const) in consts_map
  279. ), "Constants should be either aliased or appear in graph signature"
  280. ph_name, _ = consts_map[id(const)][0]
  281. add_to_consts_map(id(const), ph_name, const_name)
  282. added_params_buffers.add(s.target)
  283. # add aliased/unused params and buffers that don't appear in graph signature
  284. for fqn, tensor in export_module.state_dict.items():
  285. if fqn not in added_params_buffers:
  286. if id(tensor) not in consts_map:
  287. # completely unused (no weight-sharing), ignore.
  288. # this weight doesn't appear in graph module,
  289. # so won't cause FQN assignment issues
  290. continue
  291. ph_name, _ = consts_map[id(tensor)][0]
  292. add_to_consts_map(id(tensor), ph_name, fqn)
  293. # node name -> list of possible targets
  294. inputs_to_state: Dict[str, List[str]] = {}
  295. for node_target in consts_map.values():
  296. targets = [t[1] for t in node_target]
  297. for n, _ in node_target:
  298. inputs_to_state[n] = targets
  299. _sink_params(self, inputs_to_state, [])
  300. # Helper function to check input nodes of `module` has been processed.
  301. def check_module_inputs(module, scope):
  302. if hasattr(module, "graph"):
  303. for node in module.graph.nodes:
  304. # sink_params() should turn placeholders into get_attr nodes
  305. # for attributes that are within scope of the current
  306. # module. We allow attributes to remain as placeholders if
  307. # they are inputs in the original module signature, meaning
  308. # they are a parent module's attribute, and therefore out of
  309. # scope of the current module.
  310. if (
  311. node.op == "placeholder"
  312. and node.name in inputs_to_state
  313. and any(
  314. fqn.split(".")[: len(scope)] == scope
  315. for fqn in inputs_to_state[node.name]
  316. ) # matching scope to avoid wrong assert
  317. ):
  318. raise AssertionError(
  319. f"{node.name} was not sunk into the module {scope} which has the graph: {module.graph}"
  320. )
  321. # Recursively check the submodules.
  322. for name, submod in module.named_children():
  323. scope.append(name)
  324. check_module_inputs(submod, scope)
  325. # Recurively check all input nodes have been processed.
  326. check_module_inputs(self, [])
  327. # Cache so we don't have to compute this every time.
  328. # NOTE: this needs to be kept in sync with the placeholders in
  329. # self.graph, but currently we have no way to guarantee that.
  330. self.input_placeholders = [
  331. node for node in self.graph.nodes if node.op == "placeholder"
  332. ]
  333. self.check_input_constraints = True
  334. # TODO(zhxchen17) We can register modules ahead of time instead of reorder later.
  335. fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)}
  336. # In the case of legacy IR, we might be missing some modules from metadata.
  337. for name, _ in self.named_modules(remove_duplicate=False):
  338. if name not in fqn_order:
  339. fqn_order[name] = len(fqn_order)
  340. _reorder_submodules(self, fqn_order)
  341. assert [fqn for fqn, _ in self.named_modules(remove_duplicate=False)] == list(
  342. fqn_order.keys()
  343. )
  344. def _print_graph(self):
  345. for fqn, mod in self.named_modules():
  346. print(fqn + ":")
  347. if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph):
  348. print(mod.graph)
  349. def forward(self, *args, **kwargs):
  350. signature = self.module_call_graph[0].signature
  351. reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec)
  352. flat_args_with_path, in_spec = pytree.tree_flatten_with_path(
  353. (args, reordered_kwargs)
  354. )
  355. flat_args = [x[1] for x in flat_args_with_path]
  356. if is_fx_tracing():
  357. return_val = torch.fx.Interpreter(self, graph=self.graph).run(
  358. *flat_args, enable_io_processing=False
  359. )
  360. # For scalar return value, fx.Graph wraps in a tuple
  361. if isinstance(return_val, tuple) and len(return_val) == 1:
  362. return return_val[0]
  363. return return_val
  364. if in_spec != signature.in_spec:
  365. if not self.adapted:
  366. print(
  367. "Input treespec does not match with exported module's: \n"
  368. f"Input treespec: {in_spec}. ",
  369. f"Exported module treespec: {signature.in_spec}",
  370. )
  371. if self.flat_args_adapter is None:
  372. raise TypeError(
  373. "There is no flat args adapter sepcified. "
  374. "Are you sure you are calling this with the right arguments? "
  375. )
  376. else:
  377. if not self.adapted:
  378. print("Adapting flat arg to match exported module's treespec")
  379. flat_args = self.flat_args_adapter.adapt(
  380. target_spec=signature.in_spec,
  381. input_spec=in_spec,
  382. input_args=flat_args,
  383. )
  384. self.adapted = True
  385. if len(flat_args) != signature.in_spec.num_leaves:
  386. raise TypeError(
  387. f"Flat args adaption failed, number of args mismatch "
  388. f"Adatped: {len(flat_args)} \n"
  389. f"Exported module: {signature.in_spec.num_leaves}"
  390. )
  391. if self.check_input_constraints:
  392. # Import here to avoid an unfortunate circular dependency.
  393. # TODO(suo): untangle this.
  394. from torch._export.utils import _check_input_constraints_for_graph
  395. if self.adapted is True:
  396. # TODO(suo): The FlatArgsAdapter returns a list of flat args,
  397. # which we don't have keypaths for. For now, just create a dummy
  398. # keypath to associate with the arg.
  399. new_flat_args_with_path = [ # type: ignore[var-annotated]
  400. ((SequenceKey(idx=0), GetAttrKey(name="<unknown location>")), arg)
  401. for arg in flat_args
  402. ]
  403. else:
  404. new_flat_args_with_path = flat_args_with_path # type: ignore[assignment]
  405. _check_input_constraints_for_graph(
  406. self.input_placeholders, new_flat_args_with_path, self.range_constraints
  407. )
  408. tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
  409. *flat_args, enable_io_processing=False
  410. )
  411. return pytree.tree_unflatten(tree_out, signature.out_spec)
def unflatten(
    module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None
) -> UnflattenedModule:
    """Unflatten an ExportedProgram, producing a module with the same module
    hierarchy as the original eager module. This can be useful if you are trying
    to use :mod:`torch.export` with another system that expects a module
    hierarchy instead of the flat graph that :mod:`torch.export` usually produces.

    .. note:: The args/kwargs of unflattened modules will not necessarily match
        the eager module, so doing a module swap (e.g. :code:`self.submod =
        new_mod`) will not necessarily work. If you need to swap a module out, you
        need to set the :code:`preserve_module_call_signature` parameter of
        :func:`torch.export.export`.

    Args:
        module (ExportedProgram): The ExportedProgram to unflatten.
        flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's.

    Returns:
        An instance of :class:`UnflattenedModule`, which has the same module
        hierarchy as the original eager module pre-export.
    """
    return UnflattenedModule(module, flat_args_adapter)
def _inplace_buffer_mutations(graph: torch.fx.Graph, graph_signature) -> None:
    """Transform buffer mutations from their functionalized form into a copy_
    node in the graph.

    Functionalization represents buffer mutation by passing the buffer as an
    input and output. So for example, the eager code:
        def forward(self, x):
            self.buffer += x
            return x * x
    Will become a graph that looks like:
        def forward(self, buffer, x):
            mutated_buffer = aten.add(buffer, x)
            mul = aten.mul(x, x)
            return (mutated_buffer, mul)
    We want to inplace this into something that looks like the original eager code:
        def forward(self, buffer, x):
            mutated_buffer = aten.add(buffer, x)
            buffer.copy_(mutated_buffer)
            mul = aten.mul(x, x)
            return (mul,)
    """
    output_node = next(iter(reversed(graph.nodes)))
    assert output_node.op == "output" and len(output_node.args) == 1
    return_args = output_node.args[0]

    # Mutated-buffer outputs come first in the output tuple, one per entry in
    # buffers_to_mutate; everything after them is a real user output.
    mutation_node_to_buffer = graph_signature.buffers_to_mutate
    mutations = return_args[: len(mutation_node_to_buffer)]
    buffers_to_inputs = {v: k for k, v in graph_signature.inputs_to_buffers.items()}
    input_name_to_node = {
        node.name: node for node in graph.nodes if node.op == "placeholder"
    }

    for mutation in mutations:
        # Find the placeholder that fed this buffer into the graph.
        buffer_name = mutation_node_to_buffer[mutation.name]
        input_name = buffers_to_inputs[buffer_name]
        input_node = input_name_to_node[input_name]

        with graph.inserting_after(mutation):
            new_node = graph.create_node(
                "call_function", torch.ops.aten.copy_, (input_node, mutation)
            )
            for k, v in mutation.meta.items():
                new_node.meta[k] = v
        # Replace all uses of the previously functional mutation with our copy_ output.
        mutation.replace_all_uses_with(new_node, lambda x: x is not new_node)

    # Remove the mutated buffer from the graph outputs, since we don't need to
    # thread it through anymore. We don't need to handle the inputs, which will
    # be handled by _sink_params.
    user_outputs = tuple(
        return_args[len(mutation_node_to_buffer) :],
    )
    output_node.args = ((user_outputs),)
  479. def _is_prefix(candidate, target):
  480. """Check whether `candidate` is a prefix of `target`."""
  481. return len(candidate) < len(target) and target[: len(candidate)] == candidate
  482. def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
  483. if parent_fqn == "":
  484. # Handle the root module correctly.
  485. return child_fqn
  486. parent_split = parent_fqn.split(".")
  487. child_split = child_fqn.split(".")
  488. assert (
  489. child_split[: len(parent_split)] == parent_split
  490. ), f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'"
  491. return ".".join(child_split[len(parent_split) :])
def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
    """Assert that two graph-carrying modules have structurally identical graphs."""

    def graph_dump(graph: torch.fx.Graph) -> str:
        # Serialize the graph to a canonical string: node arguments are
        # rendered by positional index (%0, %1, ...) so differing node *names*
        # do not affect the comparison.
        ret = []
        nodes_idx: Dict[int, int] = {}

        def arg_dump(arg) -> str:
            if isinstance(arg, torch.fx.Node):
                return "%" + str(nodes_idx[id(arg)])
            return str(arg)

        for i, node in enumerate(graph.nodes):
            args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)]
            args_dump += [
                f"{key}={value}"
                for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
            ]
            # Only call_function targets are meaningful across graphs; other
            # ops (placeholder/get_attr/...) are compared by op kind alone.
            target = node.target if node.op == "call_function" else ""
            ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
            nodes_idx[id(node)] = i
        return "\n".join(ret)

    assert graph_dump(x.graph) == graph_dump(y.graph)
  511. def _add_spec(gm: torch.nn.Module, spec) -> str:
  512. i = 0
  513. while hasattr(gm, f"_spec_{i}"):
  514. i += 1
  515. name = f"_spec_{i}"
  516. setattr(gm, name, spec)
  517. return name
def _generate_flatten(gm: torch.nn.Module, node, spec) -> torch.fx.Node:
    # Register `spec` as an attribute on `gm`, then emit a call that flattens
    # `node` according to that spec at runtime.
    # NOTE(review): assumes `gm` exposes a `.graph` (GraphModule-like), which
    # holds for all call sites in this file.
    name = _add_spec(gm, spec)
    spec_node = gm.graph.get_attr(name)
    return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node))
def _generate_unflatten(gm: torch.nn.Module, nodes, spec) -> torch.fx.Node:
    # Mirror of _generate_flatten: register `spec` on `gm` and emit a call that
    # reassembles `nodes` into the pytree described by the spec at runtime.
    name = _add_spec(gm, spec)
    spec_node = gm.graph.get_attr(name)
    return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node))
  526. def _get_submodule(mod: torch.nn.Module, target: str):
  527. *prefix, field = target.split(".")
  528. for item in prefix:
  529. submod = getattr(mod, item, None)
  530. if submod is None:
  531. return None
  532. if not isinstance(submod, torch.nn.Module):
  533. return None
  534. mod = submod
  535. return getattr(mod, field, None)
  536. def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module):
  537. *prefix, field = target.split(".")
  538. for item in prefix:
  539. submod = getattr(mod, item, None)
  540. if submod is None:
  541. submod = torch.nn.Module()
  542. setattr(mod, item, submod)
  543. if not isinstance(submod, torch.nn.Module):
  544. return False
  545. mod = submod
  546. mod.add_module(field, module_to_add)
  547. class _ModuleFrame:
  548. def __init__(
  549. self,
  550. flat_graph,
  551. nodes,
  552. seen_nodes,
  553. seen_modules,
  554. parent,
  555. module_stack,
  556. module_id,
  557. module_call_graph: Dict[str, ModuleCallSignature],
  558. module: Optional[torch.nn.Module] = None,
  559. ):
  560. self.flat_graph = flat_graph
  561. self.nodes = nodes
  562. self.seen_nodes = seen_nodes
  563. self.seen_modules = seen_modules
  564. self.parent = parent
  565. self.module_stack = module_stack
  566. self.module_id = module_id
  567. self.module_call_graph = module_call_graph
  568. self.verbose = False
  569. self.fqn = self.module_stack[-1]
  570. if module is not None:
  571. self.module = module
  572. else:
  573. self.module = InterpreterModule(torch.fx.Graph())
  574. if self.module_id in self.seen_modules:
  575. self.cached_graph_module = self.seen_modules[self.module_id]
  576. else:
  577. self.cached_graph_module = None
  578. self.seen_modules[self.module_id] = self.module
  579. self.graph = self.module.graph
  580. # Mapping of nodes in the flat graph to nodes in this graph.
  581. self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {}
  582. self.node_to_placeholder = {}
  583. self.parent_call_module: Optional[torch.fx.Node] = None
  584. if parent is not None:
  585. accessor = _compute_accessor(parent.fqn, self.fqn)
  586. _add_submodule(
  587. parent.module,
  588. accessor,
  589. (
  590. self.module
  591. if self.cached_graph_module is None
  592. else self.cached_graph_module
  593. ),
  594. )
  595. self.parent_call_module = parent.graph.call_module(accessor)
  596. signature = module_call_graph.get(self.fqn)
  597. if signature is not None and self.parent is not None:
  598. assert signature.in_spec.num_children == 2
  599. args_spec = signature.in_spec.children_specs[0]
  600. kwargs_spec = signature.in_spec.children_specs[1]
  601. assert args_spec.context is None
  602. assert kwargs_spec.context is not None
  603. with self.graph.inserting_after(None):
  604. arg_nodes = []
  605. for idx in range(args_spec.num_children):
  606. arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}"))
  607. kwarg_nodes = {}
  608. for name in kwargs_spec.context:
  609. kwarg_nodes[name] = self.graph.placeholder(name)
  610. flat_args = _generate_flatten(
  611. self.module,
  612. (tuple(arg_nodes), kwarg_nodes),
  613. signature.in_spec,
  614. )
  615. for idx, arg in enumerate(signature.inputs):
  616. flat_arg_node = self.graph.create_node(
  617. op="call_function",
  618. target=operator.getitem,
  619. args=(flat_args, idx),
  620. name=(
  621. arg.name
  622. if not isinstance(arg, ConstantArgument)
  623. else f"_constant_{idx}"
  624. ),
  625. )
  626. if isinstance(arg, ConstantArgument):
  627. continue
  628. flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
  629. self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node
  630. with self.parent.graph.inserting_before(self.parent_call_module):
  631. input_nodes: List[Optional[torch.fx.Node]] = []
  632. for input in signature.inputs:
  633. if isinstance(input, ConstantArgument) and input.value is None:
  634. input_nodes.append(None)
  635. else:
  636. assert isinstance(input, (TensorArgument, SymIntArgument))
  637. input_nodes.append(
  638. self.parent.remap_input(self.seen_nodes[input.name])
  639. )
  640. inputs_node = _generate_unflatten(
  641. self.parent.module,
  642. input_nodes,
  643. signature.in_spec,
  644. )
  645. args_node = self.parent.graph.call_function(
  646. operator.getitem, (inputs_node, 0)
  647. )
  648. kwargs_node = self.parent.graph.call_function(
  649. operator.getitem, (inputs_node, 1)
  650. )
  651. arg_nodes = [
  652. self.parent.graph.call_function(operator.getitem, (args_node, i))
  653. for i in range(args_spec.num_children)
  654. ]
  655. kwarg_nodes = {
  656. k: self.parent.graph.call_function(
  657. operator.getitem, (kwargs_node, k)
  658. )
  659. for k in kwargs_spec.context
  660. }
  661. assert self.parent_call_module is not None
  662. self.parent_call_module.args = tuple(arg_nodes)
  663. self.parent_call_module.kwargs = kwarg_nodes
  664. def add_placeholder(self, x):
  665. assert self.fqn != "", f"Cannot add placeholder {x} to root module"
  666. assert x.graph is self.flat_graph
  667. # x is not in subgraph, create a new placeholder for subgraph
  668. with self.graph.inserting_before(None):
  669. placeholder_node = self.graph.placeholder(x.name, type_expr=x.type)
  670. # copy all meta fields, even if some fields might be irrelvant for
  671. # the placeholder node
  672. placeholder_node.meta = copy.copy(x.meta)
  673. self.node_to_placeholder[x] = placeholder_node
  674. def remap_input(self, x):
  675. assert x.graph is self.flat_graph
  676. if x in self.node_map:
  677. return self.node_map[x]
  678. if x not in self.node_to_placeholder:
  679. self.add_placeholder(x)
  680. if self.parent_call_module is not None:
  681. # Important to *prepend* the output to match how we are
  682. # inserting placeholder nodes.
  683. self.parent_call_module.insert_arg(0, self.parent.remap_input(x))
  684. return self.node_to_placeholder[x]
def finalize_outputs(self):
    """Emit this subgraph's output node and wire the result into the parent.

    Two cases:
    - The module has a recorded call signature (and a parent): the outputs
      named by the signature are re-packed with an unflatten node inside
      this subgraph, and a matching flatten node is emitted in the parent
      after the call_module node.
    - Otherwise: every node copied into this subgraph that still has a user
      outside of it is exposed as an output.

    In both cases, nodes in the parent's `node_map` are updated so later
    parent nodes can reference this module's outputs.
    """
    orig_outputs = []

    signature = self.module_call_graph.get(self.fqn)
    if signature is not None and self.parent is not None:
        # Collect the flat-graph nodes named by the signature's outputs.
        for output in signature.outputs:
            if isinstance(output, (TensorArgument, SymIntArgument)):
                orig_outputs.append(self.seen_nodes[output.name])
            else:
                raise RuntimeError(
                    f"Unsupported data type for output node: {output}"
                )

        # Re-pack this subgraph's flat outputs into the original pytree
        # structure described by the signature's out_spec.
        tree_out_node = _generate_unflatten(
            self.module,
            tuple(
                self.node_map[self.seen_nodes[output.name]]
                for output in orig_outputs
            ),
            signature.out_spec,
        )
        # In the parent graph, flatten the structured result of the
        # call_module node back into a flat list of values.
        parent_out: Optional[torch.fx.Node] = _generate_flatten(
            self.parent.module, self.parent_call_module, signature.out_spec
        )
        graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node
    else:
        graph_outputs = []
        # Iterate through nodes we have copied into self.graph.
        for orig_node in self.node_map.keys():
            for user_node in orig_node.users:
                if user_node.name not in self.seen_nodes:
                    # external user node, need to expose as an output
                    orig_outputs.append(orig_node)
                    graph_outputs.append(self.node_map[orig_node])
                    break

        parent_out = self.parent_call_module
        # A single output is emitted unwrapped, not as a 1-element list.
        if len(graph_outputs) == 1:
            graph_outputs = graph_outputs[0]

    assert isinstance(graph_outputs, (list, torch.fx.Node))

    self.graph.output(graph_outputs)

    # Rewrite outputs in parent module
    if parent_out is None:
        return

    # Propagate "val" metadata onto the parent-side node (a single node's
    # val, or a list of vals for multiple outputs).
    parent_out.meta["val"] = (
        graph_outputs.meta.get("val")
        if isinstance(graph_outputs, torch.fx.Node)
        else [o.meta.get("val") for o in graph_outputs]
    )

    if len(orig_outputs) == 1 and signature is None:
        # Single anonymous output: the parent-side node *is* the value.
        self.parent.node_map[orig_outputs[0]] = parent_out
    else:
        for i, orig_output in enumerate(orig_outputs):
            # Use Proxy to record getitem access.
            proxy_out = torch.fx.Proxy(parent_out)[i].node  # type: ignore[index]
            proxy_out.meta["val"] = orig_output.meta.get("val")
            self.parent.node_map[orig_output] = proxy_out

    if self.cached_graph_module is not None:
        # NOTE(review): cached_graph_module presumably holds the graph built
        # for another instance of the same (shared) module — confirm; this
        # asserts the freshly built graph is equivalent to it.
        _verify_graph_equivalence(self.cached_graph_module, self.module)
  741. def copy_node(self, node):
  742. self.print("copying", node.format_node())
  743. self.node_map[node] = self.graph.node_copy(node, self.remap_input)
  744. self.seen_nodes[node.name] = node
  745. def run_outer(self):
  746. i = 0
  747. for node in self.flat_graph.nodes:
  748. self.print(i, node.meta.get("nn_module_stack"), node.format_node())
  749. i += 1
  750. # Copy all graph inputs
  751. node_idx: int = 0
  752. node = self.nodes[node_idx]
  753. while node.op == "placeholder":
  754. self.copy_node(node)
  755. node_idx += 1
  756. node = self.nodes[node_idx]
  757. self.run_from(node_idx)
  758. # Copy graph outputs
  759. for node in self.flat_graph.nodes:
  760. if node.op == "output":
  761. self.copy_node(node)
  762. def print(self, *args, **kwargs):
  763. if self.verbose:
  764. print(*args, **kwargs)
def run_from(self, node_idx):
    """Consume flat-graph nodes starting at `node_idx` that belong to this
    frame's module, recursing into a child `_ModuleFrame` whenever the
    nn_module_stack metadata shows a deeper module has started executing.

    Returns the index of the first node this frame did not consume (so the
    caller can resume from there).
    """
    module_idx = 0
    # Walk through the graph, building up a new graph with the right submodules
    while node_idx < len(self.nodes):
        node = self.nodes[node_idx]
        assert node.op != "placeholder"

        self.print()
        self.print("STEP", node_idx, node.format_node())
        self.print(self.module_stack)
        if node.op == "output":
            if len(self.module_stack) == 1:
                # We want the output node of the original graph to be handled
                # specially by the outermost stack frame (in run_outer). So
                # skip finalization here.
                return node_idx

            # We've reached the end of the graph. Wrap up all the existing stack frames.
            self.finalize_outputs()
            return node_idx

        if len(node.meta.get("nn_module_stack", {})) == 0:
            raise RuntimeError(f"Unable to find nn_module_stack for node {node}")

        nn_module_stack = node.meta["nn_module_stack"]
        # NOTE(review): imported locally, presumably to avoid an import
        # cycle with torch._export — confirm before hoisting to the top.
        from torch._export.passes._node_metadata_hook import (
            _EMPTY_NN_MODULE_STACK_KEY,
        )

        if (
            len(nn_module_stack) == 1
            and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack
        ):
            # Empty case from the node_metadata_hook
            node_module_stack = self.module_stack
        else:
            # Module paths recorded for this node, outermost module first.
            node_module_stack = [
                path for path, ty in node.meta["nn_module_stack"].values()
            ]

        if node_module_stack[: len(self.module_stack)] != self.module_stack:
            # This means that the current module is done executing and the
            # current node is the beginning of a new module.
            #
            # In this case, we should finalize this module and return without
            # incrementing the node counter.
            self.finalize_outputs()
            self.print("outlining", self.fqn)
            self.print(self.graph)
            return node_idx

        assert node_module_stack is not None

        if _is_prefix(self.module_stack, node_module_stack):
            # This means that the current node represents the execution of a new
            # module.
            next_module = node_module_stack[len(self.module_stack)]
            self.print("Creating new stack frame for", next_module)
            # Run a nested version of module outliner from the current node
            # counter. Once it is complete, continue from that point.
            node_idx = _ModuleFrame(
                self.flat_graph,
                self.nodes,
                self.seen_nodes,
                self.seen_modules,
                self,
                self.module_stack + [next_module],
                list(node.meta["nn_module_stack"].keys())[len(self.module_stack)],
                self.module_call_graph,
            ).run_from(node_idx)
            module_idx += 1
            continue

        # The only remaining possibility is that we are in the right stack
        # frame. Copy the node into this frame's graph and increment the node counter.
        assert node_module_stack == self.module_stack
        self.copy_node(node)
        node_idx += 1
  834. def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
  835. seen_nodes: Dict[str, torch.fx.Node] = {}
  836. seen_modules: Dict[int, torch.nn.Module] = {}
  837. _ModuleFrame(
  838. orig_graph,
  839. tuple(orig_graph.nodes),
  840. seen_nodes,
  841. seen_modules,
  842. None,
  843. [""],
  844. "",
  845. {
  846. entry.fqn: entry.signature
  847. for entry in root_module.module_call_graph
  848. if entry.signature
  849. },
  850. module=root_module,
  851. ).run_outer()
  852. def _reorder_submodules(
  853. parent: torch.nn.Module, fqn_order: Dict[str, int], prefix: str = ""
  854. ):
  855. # TODO Can be optimized by adding submodules ahead of time.
  856. if prefix == "":
  857. for fqn in list(fqn_order.keys())[1:]:
  858. if _get_submodule(parent, fqn) is None:
  859. _add_submodule(parent, fqn, torch.nn.Module())
  860. children = []
  861. for name, child in list(parent._modules.items()):
  862. if child is None:
  863. continue
  864. fqn = prefix + name
  865. _reorder_submodules(child, fqn_order, prefix=fqn + ".")
  866. delattr(parent, name)
  867. children.append((fqn_order[fqn], name, child))
  868. children.sort(key=operator.itemgetter(0))
  869. for _, name, child in children:
  870. parent.register_module(name, child)
def _sink_params(
    module: torch.nn.Module,
    inputs_to_state: Dict[str, List[str]],
    scope: List[str],
) -> Dict[int, List[str]]:
    """Sink params, buffers, and constants from graph inputs into get_attr nodes.

    Exported modules are purely functional, so they pass their parameters and
    buffers in as inputs to the graph.

    To replicate eager's semantics, we need to get them from the module state
    via get_attr instead.

    module: GraphModule, potentially containing nested submodules.
    inputs_to_state: mapping graph input names to the corresponding key in the state_dict.
    scope: tracks where we are in the module hierarchy, so that we can emit the
        right `getattr(self, "foo.bar")` calls, etc.

    Returns a mapping from id(module) to the names of the placeholder nodes
    removed from that module's graph, aggregated over all submodules.
    """
    # This dict records inputs removed by child modules.
    # Maps the module object id to the list of placeholder node names
    # in the child module that were removed.
    module_id_to_inputs_removed: Dict[int, List[str]] = defaultdict(list)

    # We need to use _modules here instead of named_children(), because we
    # explicitly want duplicate modules to show up in the traversal.
    for name, submodule in module._modules.items():
        submod_id_to_inputs_removed = _sink_params(
            cast(torch.nn.Module, submodule), inputs_to_state, scope + [name]
        )
        for k, v in submod_id_to_inputs_removed.items():
            module_id_to_inputs_removed[k].extend(v)

    if not hasattr(module, "graph"):
        # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList)
        return module_id_to_inputs_removed

    graph = module.graph
    inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes))
    # NOTE(review): assumes the graph has at least one placeholder node; a
    # graph with no inputs would raise IndexError here — confirm upstream
    # invariant.
    the_last_input = inputs[-1]

    # Also remove from call_module nodes
    call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
    for node in call_module_nodes:
        submodule = _recursive_getattr(module, node.target.split("."))
        # remove placeholder from call_module node arguments, only if we've
        # erased the placeholder node in the corresponding _sink_params() call
        if submodule is not None and id(submodule) in module_id_to_inputs_removed:
            node.args = tuple(
                filter(
                    lambda n: n.name not in module_id_to_inputs_removed[id(submodule)],
                    node.args,
                )
            )

    # Filter out inputs_to_state corresponding to current scope.
    inputs_to_state_of_scope: Dict[torch.fx.Node, List[str]] = {}
    for node in inputs:
        if node.name not in inputs_to_state:
            continue

        # Pick the state name whose path lies within the current scope.
        state_name = None
        for sn in inputs_to_state[node.name]:
            sn_split = sn.split(".")
            if sn_split[: len(scope)] == scope:
                state_name = sn_split
                break

        # If there's a mismatch between scope name and state name, then
        # there must be multiple scopes pointing to the same state name,
        # meaning some modules are shared. In such case, we can simply skip
        # updating the current node because another later iteration will
        # take care of this input node when the unique match between scope
        # and state name occurs. To make sure this always happens, we should
        # enforce the invariant that no placeholder node in the unflattened
        # graph appears in inputs_to_state dict, which means all the extra
        # input nodes have been handled.
        if state_name is None:
            continue
        inputs_to_state_of_scope[node] = state_name

    # Record the names of removed inputs for the return value.
    inputs_removed: List[str] = []

    for node, state_name in inputs_to_state_of_scope.items():
        if len(node.users) > 0:
            # Strip the scope prefix to get the attribute path on `module`.
            attr_path = state_name[len(scope) :]
            state_attr = _recursive_getattr(module, attr_path)
            assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject))

            # Make sure the newly created get_attr node is placed after the last placeholder node
            with graph.inserting_after(the_last_input):
                new_node = graph.create_node("get_attr", ".".join(attr_path))

            node.replace_all_uses_with(new_node, propagate_meta=True)
        # Erase the placeholder even if it had no users.
        graph.erase_node(node)
        inputs_removed.append(node.name)

    if isinstance(module, InterpreterModule):
        module.finalize()

    return {id(module): inputs_removed}
  956. def _recursive_getattr(obj, attr_path):
  957. for attr in attr_path:
  958. if not hasattr(obj, attr):
  959. return None
  960. obj = getattr(obj, attr)
  961. return obj