utils.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. # mypy: allow-untyped-defs
  2. import ast
  3. import dataclasses
  4. import inspect
  5. import math
  6. import operator
  7. import re
  8. from inspect import Parameter
  9. from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
  10. import torch
  11. from torch._subclasses.fake_tensor import FakeTensor
  12. from torch.export import ExportedProgram
  13. from torch.export.exported_program import (
  14. _name_hoo_subgraph_placeholders,
  15. _rename_without_collisions,
  16. )
  17. from torch.export.graph_signature import InputKind, OutputKind
  18. from torch.utils._pytree import (
  19. _register_pytree_node,
  20. Context,
  21. FlattenFunc,
  22. FromDumpableContextFn,
  23. GetAttrKey,
  24. KeyPath,
  25. keystr,
  26. MappingKey,
  27. SequenceKey,
  28. ToDumpableContextFn,
  29. tree_flatten_with_path,
  30. UnflattenFunc,
  31. )
  32. placeholder_prefixes = {
  33. InputKind.USER_INPUT: "",
  34. InputKind.PARAMETER: "p_",
  35. InputKind.BUFFER: "b_",
  36. InputKind.CONSTANT_TENSOR: "c_",
  37. InputKind.CUSTOM_OBJ: "obj_",
  38. InputKind.TOKEN: "token",
  39. }
def _check_input_constraints_for_graph(
    input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints
):
    """Validate runtime inputs against the placeholder metadata recorded at export.

    Args:
        input_placeholders: the graph's placeholder nodes, paired positionally
            with ``flat_args_with_path``.
        flat_args_with_path: ``(key_path, arg)`` pairs for the flattened runtime
            inputs; the first key of each path is a ``SequenceKey`` selecting
            positional args (index 0) or kwargs.
        range_constraints: mapping keyed by a symbolic dimension's sympy
            expression; values are converted to ``(min, max)`` with
            ``_convert_range_to_int``.

    Raises:
        RuntimeError: if the input count differs from the placeholder count, a
            tensor placeholder receives a non-tensor, a tensor has the wrong
            rank, a dimension violates its static value, symbolic equality, or
            range constraint, or an int/float/str input differs (in type or
            value) from the recorded constant.
    """

    def get_keystr(key_path: KeyPath) -> str:
        """For a given index into the flat_args, return a human readable string
        describing how to access it, e.g. "*args["foo"][0].bar"
        """
        # Prefix positional paths with "*args" to make it clearer where the
        # arguments come from; keyword paths start with the kwarg's name.
        # Ultimately we ought to serialize the original arg names for the best
        # error message here.
        args_kwargs_key_path = key_path[0]
        assert isinstance(args_kwargs_key_path, SequenceKey)
        if args_kwargs_key_path.idx == 0:
            return f"*args{keystr(key_path[1:])}"
        else:
            kwarg_key = key_path[1]
            assert isinstance(kwarg_key, MappingKey)
            name = str(kwarg_key)[1:-1]  # get rid of the enclosed []
            return f"{name}{keystr(key_path[2:])}"

    import sympy

    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
        _convert_range_to_int,
    )
    from torch.utils._sympy.solve import try_solve

    if len(flat_args_with_path) != len(input_placeholders):
        raise RuntimeError(
            "Unexpected number of inputs "
            f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})"
        )
    # NOTE: export already guarantees that the same symbol is used in metadata
    # for all InputDims related by equality constraints, so we can just unify
    # symbols with given input dimension values to check equality constraints.
    unification_map: Dict[sympy.Symbol, Any] = {}
    for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
        node_val = node.meta.get("val")
        if isinstance(node_val, FakeTensor):
            if not isinstance(arg, torch.Tensor):
                raise RuntimeError(
                    f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}",
                )

            if len(node_val.shape) != len(arg.shape):
                raise RuntimeError(
                    f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape "
                    f"(expected {node_val.shape}, got {arg.shape})"
                )

            for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)):
                # TODO(avik): Assert the following property in the IR verifier:
                # node_dim is either an int or a SymInt containing an int or a unary sympy.Expr
                if (
                    isinstance(node_dim, torch.SymInt)
                    and len(node_dim.node.expr.free_symbols) == 1
                ):
                    # Symbolic dim that depends on exactly one symbol.
                    symbol = next(iter(node_dim.node.expr.free_symbols))
                    if symbol in unification_map:
                        # Symbol already bound by an earlier dim: the expression
                        # evaluated under that binding must equal this dim.
                        existing_dim = node_dim.node.expr.subs(unification_map)
                        if arg_dim != existing_dim:
                            raise RuntimeError(
                                f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
                                f"{existing_dim}, but got {arg_dim}",
                            )
                    else:
                        if (
                            isinstance(arg_dim, torch.SymInt)
                            and not arg_dim.node.expr.is_number
                        ):
                            # This can happen when, say, arg is a fake tensor.
                            # We do not run checks on symbolic shapes of fake inputs as
                            # such checks can affect the shape env.
                            pass
                        else:
                            # First occurrence of this symbol: solve
                            # expr == arg_dim for it and record the integer value.
                            solution = try_solve(
                                sympy.Eq(node_dim.node.expr, arg_dim), symbol
                            )
                            if solution is None:
                                raise RuntimeError(  # noqa: B904
                                    f"Expected input {node.name}.shape[{j}] = {arg_dim} to be "
                                    f"of the form {node_dim.node.expr}, where {symbol} is an integer"
                                )
                            else:
                                unification_map[symbol] = int(solution[1])

                    # Enforce the declared value range for this dim's expression, if any.
                    if node_dim.node.expr in range_constraints:
                        min_val, max_val = _convert_range_to_int(
                            range_constraints[node_dim.node.expr]
                        )
                        # NOTE: we allow dimensions to be 0/1 at runtime
                        # (so a lower bound of 2 or less is not enforced).
                        if min_val > 2:
                            if arg_dim < min_val:
                                raise RuntimeError(
                                    f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= "
                                    f"{min_val}, but got {arg_dim}",
                                )
                        if max_val < math.inf:
                            if arg_dim > max_val:
                                raise RuntimeError(
                                    f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= "
                                    f"{max_val}, but got {arg_dim}",
                                )
                else:
                    # Static dim (or a SymInt not of the single-symbol form above).
                    if arg_dim != node_dim:
                        if isinstance(
                            node_dim, torch.SymInt
                        ):  # this means we deferred a guard from export analysis to runtime, let this pass
                            continue
                        raise RuntimeError(
                            f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
                            f"{node_dim}, but got {arg_dim}",
                        )
        elif isinstance(node_val, (int, float, str)):
            # Constant inputs must match exactly, including their Python type.
            if type(arg) != type(node_val) or arg != node_val:
                raise RuntimeError(
                    f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
                )
  149. raise RuntimeError(
  150. f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
  151. )
  152. def register_dataclass_as_pytree_node(
  153. cls: Type[Any],
  154. flatten_fn: Optional[FlattenFunc] = None,
  155. unflatten_fn: Optional[UnflattenFunc] = None,
  156. *,
  157. serialized_type_name: Optional[str] = None,
  158. to_dumpable_context: Optional[ToDumpableContextFn] = None,
  159. from_dumpable_context: Optional[FromDumpableContextFn] = None,
  160. return_none_fields: bool = False,
  161. ) -> None:
  162. assert dataclasses.is_dataclass(
  163. cls
  164. ), f"Only dataclasses can be registered with this function: {cls}"
  165. def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
  166. flattened = []
  167. flat_names = []
  168. none_names = []
  169. for f in dataclasses.fields(obj):
  170. name, val = f.name, getattr(obj, f.name)
  171. if val is not None or return_none_fields:
  172. flattened.append(val)
  173. flat_names.append(name)
  174. else:
  175. none_names.append(name)
  176. return flattened, [flat_names, none_names]
  177. def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
  178. flat_names, none_names = context
  179. return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
  180. def default_flatten_fn_with_keys(obj: Any) -> Tuple[List[Any], Context]:
  181. flattened, (flat_names, none_names) = flatten_fn(obj) # type: ignore[misc]
  182. return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names
  183. flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
  184. unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn
  185. if (to_dumpable_context is None) ^ (from_dumpable_context is None):
  186. raise ValueError(
  187. f"Both to_dumpable_context and from_dumpable_context for {cls} must "
  188. "be None or registered."
  189. )
  190. _register_pytree_node(
  191. cls,
  192. flatten_fn,
  193. unflatten_fn,
  194. serialized_type_name=serialized_type_name,
  195. flatten_with_keys_fn=default_flatten_fn_with_keys,
  196. to_dumpable_context=to_dumpable_context,
  197. from_dumpable_context=from_dumpable_context,
  198. )
  199. def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool:
  200. """
  201. Checks if the given node is a parameter within the exported program
  202. """
  203. return node.name in program.graph_signature.inputs_to_parameters
  204. def get_param(
  205. program: ExportedProgram,
  206. node: torch.fx.Node,
  207. ) -> Optional[torch.nn.Parameter]:
  208. """
  209. Returns the parameter associated with the given node in the exported program.
  210. Returns None if the node is not a parameter within the exported program
  211. """
  212. if is_param(program, node):
  213. parameter_name = program.graph_signature.inputs_to_parameters[node.name]
  214. return program.state_dict[parameter_name]
  215. return None
  216. def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool:
  217. """
  218. Checks if the given node is a buffer within the exported program
  219. """
  220. return node.name in program.graph_signature.inputs_to_buffers
  221. def get_buffer(
  222. program: ExportedProgram,
  223. node: torch.fx.Node,
  224. ) -> Optional[torch.Tensor]:
  225. """
  226. Returns the buffer associated with the given node in the exported program.
  227. Returns None if the node is not a buffer within the exported program
  228. """
  229. if is_buffer(program, node):
  230. buffer_name = program.graph_signature.inputs_to_buffers[node.name]
  231. if buffer_name in program.graph_signature.non_persistent_buffers:
  232. return program.constants[buffer_name]
  233. else:
  234. return program.state_dict[buffer_name]
  235. return None
  236. def is_lifted_tensor_constant(
  237. program: ExportedProgram,
  238. node: torch.fx.Node,
  239. ) -> bool:
  240. """
  241. Checks if the given node is a lifted tensor constant within the exported program
  242. """
  243. return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
  244. def get_lifted_tensor_constant(
  245. program: ExportedProgram,
  246. node: torch.fx.Node,
  247. ) -> Optional[torch.Tensor]:
  248. """
  249. Returns the lifted tensor constant associated with the given node in the exported program.
  250. Returns None if the node is not a lifted tensor constant within the exported program
  251. """
  252. if is_lifted_tensor_constant(program, node):
  253. lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[
  254. node.name
  255. ]
  256. return program.constants[lifted_tensor_name]
  257. return None
  258. def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule:
  259. """
  260. Splits the graph module into multiple submodules based on the node_call_back.
  261. The node_call_back should return True if the node is a delimiter. Delimiter will be
  262. the first node in the next submodule.
  263. """
  264. from torch.fx.passes.split_module import split_module
  265. split_map = {}
  266. split_id = 0
  267. for node in gm.graph.nodes:
  268. if node_call_back(node):
  269. split_id += 1
  270. split_map[node] = split_id
  271. new_gm = split_module(
  272. gm,
  273. gm,
  274. lambda node: split_map[node],
  275. keep_original_order=True,
  276. keep_original_node_name=True,
  277. )
  278. # Keep the codegen from original graph module to preserve e.g. pytree info.
  279. new_gm.graph._codegen = gm.graph._codegen
  280. new_gm.recompile()
  281. return new_gm
  282. def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
  283. """Returns the nodes that match the node_call_back as a list."""
  284. return [node for node in nodes if node_call_back(node)]
  285. def nodes_first(
  286. nodes: List[torch.fx.Node], node_call_back=None
  287. ) -> Optional[torch.fx.Node]:
  288. """
  289. Returns the first node that matches the node_call_back. If no node matches, returns None.
  290. When node_call_back is None, returns the first node in the node list.
  291. """
  292. ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True)
  293. if len(ret) > 0:
  294. return ret[0]
  295. return None
  296. def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int:
  297. """Returns the number of nodes that match the node_call_back."""
  298. return len(nodes_filter(nodes, node_call_back))
  299. def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
  300. """
  301. Sequentially visit the nodes list and invoke node_call_back on each element.
  302. Returns the nodes list after the node_call_back is invoked on each element.
  303. """
  304. for node in nodes:
  305. node_call_back(node)
  306. return nodes
  307. def node_replace_(
  308. old_node: torch.fx.Node, new_node: torch.fx.Node, delete_old: bool = False
  309. ) -> None:
  310. """
  311. Replace all uses of old_node with new_node.
  312. """
  313. old_node.replace_all_uses_with(new_node)
  314. if delete_old:
  315. old_node.users.clear()
  316. old_node.graph.erase_node(old_node)
  317. def node_inline_(call_mod_node: torch.fx.Node) -> None:
  318. """
  319. Inline the submodule of the given node into the parent module.
  320. Note: we only support the case where submodule takes tensors inputs.
  321. """
  322. assert call_mod_node.op == "call_module"
  323. gm = call_mod_node.graph.owning_module
  324. assert isinstance(call_mod_node.target, str)
  325. sub_gm = getattr(gm, call_mod_node.target)
  326. phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder")
  327. body = (
  328. node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output")
  329. )
  330. output = [node for node in sub_gm.graph.nodes if node.op == "output"]
  331. for ph, arg in zip(phs, call_mod_node.args):
  332. assert isinstance(arg, torch.fx.Node)
  333. node_replace_(ph, arg, delete_old=True)
  334. with gm.graph.inserting_before(call_mod_node):
  335. for node in body:
  336. new_node = gm.graph.node_copy(node)
  337. node_replace_(node, new_node, delete_old=True)
  338. if len(output) > 0:
  339. assert len(output) == 1 and len(output[0].args) == 1
  340. new_output = output[0].args[0]
  341. if isinstance(new_output, torch.fx.Node):
  342. node_replace_(call_mod_node, new_output, delete_old=True)
  343. elif isinstance(new_output, (list, tuple)):
  344. # Inline the get_item calls for the output node.
  345. get_item_users = nodes_filter(
  346. list(call_mod_node.users.keys()),
  347. lambda node: node.op == "call_function"
  348. and node.target == operator.getitem,
  349. )
  350. # get_item_node.args[1] is the idx referring to new_output[idx]
  351. nodes_map(
  352. get_item_users,
  353. lambda get_item_node: node_replace_(
  354. get_item_node,
  355. new_output[get_item_node.args[1]],
  356. delete_old=True,
  357. ),
  358. )
  359. call_mod_node.graph.erase_node(call_mod_node)
  360. else:
  361. raise NotImplementedError(
  362. f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes."
  363. )
  364. else:
  365. call_mod_node.graph.erase_node(call_mod_node)
  366. gm.delete_all_unused_submodules()
  367. gm.recompile()
  368. return gm
  369. def _get_torch_jit_trace_forward_signature(mod: torch.nn.Module):
  370. """
  371. Get source code and parse argument names using AST. The function returns
  372. a signature of the forward() function.
  373. # TODO: Directly provide inspect.signature compatible TS-d module.
  374. """
  375. ast_mod = ast.parse(mod.code)
  376. ast_func_def: ast.FunctionDef = ast_mod.body[0] # type: ignore[assignment]
  377. # FIXME(jiashenc): TorchScript should only allow positional or keywords arguments.
  378. arg_type_map = {"args": Parameter.POSITIONAL_OR_KEYWORD}
  379. # Traverse all argument types in AST tree and create associated parameters.
  380. param_list = []
  381. for arg_type, param_type in arg_type_map.items():
  382. arg_name_list = [a.arg for a in getattr(ast_func_def.args, arg_type)]
  383. for arg_name in arg_name_list:
  384. if arg_name == "self":
  385. continue # Skip self argument.
  386. param_list.append(inspect.Parameter(arg_name, param_type))
  387. return inspect.Signature(parameters=param_list)
  388. def _bind_signature_to_inputs(mod, fake_args, fake_kwargs):
  389. if isinstance(mod, (torch.jit.ScriptModule, torch.jit.TracedModule)):
  390. sig = _get_torch_jit_trace_forward_signature(mod)
  391. # Sanity check for placeholder names coming from TorchScript.
  392. assert len(sig.parameters) == len(fake_args) + len(fake_kwargs), (
  393. "Arguments other than POSITIONAL_OR_KEYWORD kinds in forward() "
  394. "are not supported in _get_torch_jit_trace_forward_signature"
  395. )
  396. else:
  397. sig = inspect.signature(mod.forward)
  398. return sig.bind(*fake_args, **fake_kwargs).arguments
  399. def placeholder_naming_pass(
  400. gm: torch.fx.GraphModule,
  401. export_graph_signature: torch.export.ExportGraphSignature,
  402. mod: torch.nn.Module,
  403. fake_args,
  404. fake_kwargs,
  405. fake_params_buffers,
  406. constants: Dict[str, Any],
  407. ) -> None:
  408. """
  409. This pass is run at the end of _export_non_strict() to assign better placeholder node names:
  410. - User inputs:
  411. These follow the signature of mod.forward(), e.g. forward(x, y) produces nodes x, y.
  412. For nested inputs from dictionaries, lists, tuples, or dataclasses,
  413. the names are a concatenation of the path to the tensor.
  414. e.g. x = {
  415. 'a': torch.randn(),
  416. 'b': [torch.randn(), torch.randn()]
  417. }
  418. produces nodes x_a, x_b_0, x_b_1.
  419. - Parameters/buffers/constants/custom objects:
  420. These follow the FQN of the object, prefixed by "p", "b", "c", "obj" respectively.
  421. e.g. self.bar.l0.weight produces "p_bar_l0_weight".
  422. - Effect tokens:
  423. These are named token, token_1, ...
  424. """
  425. def _strip_name(x):
  426. if x.startswith("L__self___"):
  427. x = x[len("L__self___") :]
  428. x = re.sub(r"[^a-zA-Z0-9]", "_", x)
  429. return x
  430. def _extract_pytree_key(x):
  431. if isinstance(x, MappingKey):
  432. x = re.sub(r"[^a-zA-Z0-9]", "_", str(x.key))
  433. return x
  434. elif isinstance(x, SequenceKey):
  435. return str(x.idx)
  436. elif isinstance(x, GetAttrKey):
  437. return x.name
  438. else:
  439. raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}")
  440. name_map: Dict[str, str] = {}
  441. # map user input names with mod.forward() signature
  442. combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs)
  443. flat_args_with_path, _ = tree_flatten_with_path(combined_args)
  444. user_input_names = [
  445. spec.arg.name
  446. for spec in export_graph_signature.input_specs
  447. if spec.kind == InputKind.USER_INPUT
  448. ]
  449. # use pytree path to name nested user inputs
  450. for (arg_path, arg), user_input_name in zip(flat_args_with_path, user_input_names):
  451. if user_input_name:
  452. _rename_without_collisions(
  453. name_map,
  454. user_input_name,
  455. placeholder_prefixes[InputKind.USER_INPUT]
  456. + "_".join(_extract_pytree_key(x).lower() for x in arg_path),
  457. is_placeholder=True,
  458. )
  459. # use graph signature input specs to map param/buffer/constant names
  460. # name effect tokens as token, token_1, ... (these aren't visible to user)
  461. for spec in export_graph_signature.input_specs:
  462. if spec.kind == InputKind.USER_INPUT:
  463. continue
  464. if spec.kind == InputKind.TOKEN:
  465. base_name = ""
  466. else:
  467. base_name = _strip_name(spec.target).lower()
  468. base_name = re.sub(r"[^a-zA-Z0-9]", "_", base_name)
  469. _rename_without_collisions(
  470. name_map,
  471. spec.arg.name,
  472. placeholder_prefixes[spec.kind] + base_name,
  473. is_placeholder=True,
  474. )
  475. # handle naming collisions with call_function/get_attr inputs.
  476. # here, we want to prioritize user input names over call_function names
  477. # e.g. not have forward(self, mul): lead to a placeholder node called mul_13,
  478. # so we increment the suffix of call_function nodes as needed
  479. for node in gm.graph.nodes:
  480. if node.op == "placeholder":
  481. continue
  482. _rename_without_collisions(name_map, node.name, node.name)
  483. # assign new node names
  484. for node in gm.graph.nodes:
  485. if node.op == "placeholder":
  486. assert node.name in name_map
  487. node.name = node.target = name_map[node.name]
  488. elif node.name in name_map:
  489. node.name = name_map[node.name]
  490. # propagate names to higher order op subgraphs
  491. _name_hoo_subgraph_placeholders(gm)
  492. # re-generate graph module code
  493. gm.recompile()
  494. # modify graph signature (input specs, output specs, user input mutations)
  495. for spec in export_graph_signature.input_specs:
  496. assert spec.arg.name in name_map
  497. spec.arg.name = name_map[spec.arg.name]
  498. if ( # handle targets for custom objects
  499. spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map
  500. ):
  501. spec.target = name_map[spec.target][4:] # strip obj_ prefix
  502. for spec in export_graph_signature.output_specs:
  503. if spec.arg.name in name_map:
  504. spec.arg.name = name_map[spec.arg.name]
  505. if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map:
  506. spec.target = name_map[spec.target]
  507. # rename keys in constants dict for custom objects
  508. for name in list(constants.keys()):
  509. constant = constants[name]
  510. if name in name_map and not isinstance(
  511. constant, torch.Tensor
  512. ): # rename custom objects with generic names
  513. new_name = name_map[name]
  514. if (
  515. new_name != name
  516. and re.match(r"arg(\d+)_1", name)
  517. and new_name != placeholder_prefixes[InputKind.CUSTOM_OBJ] + name
  518. ):
  519. constants[new_name] = constant
  520. del constants[name]