# mypy: allow-untyped-defs
import operator
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.export._trace
from torch.export.exported_program import ExportedProgram
from torch.export.graph_signature import (
    ConstantArgument,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.fx import subgraph_rewriter
from torch.onnx.utils import _create_jit_graph
from torchgen.model import FunctionSchema


def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
    def pattern(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
        scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int)
        div_scalar_mode = torch.ops.aten.div.Scalar_mode(
            scalar_tensor, scale, rounding_mode="trunc"
        )
        int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode)
        return int_tensor

    def replacement(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
        return sym_size_int // scale

    replaced_patterns = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
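
# For illustration (a hedged sketch, not part of the pass itself): given a graph
# that computed `int(torch.scalar_tensor(x.shape[0]) / 2)` with trunc rounding,
# `inplace_optimize_sym_size_div` collapses the matched four-node chain into a
# single floor division on the symbolic size, i.e. `sym_size_int // 2`.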


def normalize_name(name: str) -> str:
    return name.replace(".", "_")


def ir_name_to_func_name(name: str) -> str:
    """prim::If -> convert_prim_If"""
    name_list = name.split("::")
    return "convert_" + "_".join(name_list)


# These operators are automatically populated as instance methods of
# TS2FXGraphConverter, named convert_<namespace>_<opname>().
# See __init__ for the method population implementation.
kind_to_standard_operators = {
    "prim::TupleIndex": operator.getitem,
    "aten::__is__": operator.is_,
    "aten::__isnot__": operator.is_not,
    "aten::__not__": operator.not_,
    "aten::__contains__": operator.contains,
}
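
# For example (illustrative): "prim::TupleIndex" produces a generated handler
# named convert_prim_TupleIndex, which dispatches to _convert_standard_operators
# and emits operator.getitem into the FX graph.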


def get_op_overload(node: torch._C.Node):
    schema_str = node.schema()
    schema = FunctionSchema.parse(schema_str)
    ns, op_name = str(schema.name.name).split("::")
    override = schema.name.overload_name

    try:
        op_overload_mod = getattr(torch.ops, ns)
        op_overload_packet = getattr(op_overload_mod, op_name)
        if override:
            op_overload = getattr(op_overload_packet, override)
        else:
            op_overload = op_overload_packet.default
    except Exception as e:
        raise RuntimeError(
            f"Unable to find operator {node.kind()} with schema {schema_str}"
        ) from e

    return op_overload
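
# For example (illustrative): a node whose schema is
# "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
# resolves to torch.ops.aten.add.Tensor, while a schema with an empty overload
# name resolves to the packet's .default overload.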


class TS2FXGraphConverter:
    def __init__(
        self,
        ts_graph: Union[torch._C.Graph, torch._C.Block],
        param_names: Set[str],
        buffer_names: Set[str],
    ):
        self.ts_graph = ts_graph
        self.param_names = param_names
        self.buffer_names = buffer_names

        self.fx_graph: torch.fx.Graph = torch.fx.Graph()
        self.input_specs: List[InputSpec] = []
        self.output_specs: List[OutputSpec] = []

        self.name_to_node: Dict[
            str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
        ] = {}
        self.constant_map: Dict[str, Any] = {}
        self.attribute_map: Dict[str, Any] = {}
        self.tensor_constants: Dict[str, torch.Tensor] = {}

        self.subgraphs: Dict[str, torch.fx.GraphModule] = {}

        # Populate methods for the standard operators.
        for k in kind_to_standard_operators.keys():
            handler_func_name = ir_name_to_func_name(k)
            # Create an indirect function call:
            # convert_<namespace>_<opname> --> lambda node: _convert_standard_operators(node)
            setattr(
                self,
                handler_func_name,
                lambda node: self._convert_standard_operators(node),
            )

    def add_subgraph(self, subgraph) -> str:
        name = f"subgraph_{len(self.subgraphs)}"
        self.subgraphs[name] = subgraph
        return name

    def get_args_kwargs(self, node: torch._C.Node, schema):
        args = []
        kwargs = {}
        for input, schema_arg in zip(node.inputs(), schema.arguments):
            if schema_arg.kwarg_only:
                kwargs[schema_arg.name] = self.get_fx_value(input)
            else:
                args.append(self.get_fx_value(input))

        return tuple(args), kwargs
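
    # For example (illustrative): for the schema
    # "aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor",
    # the positional inputs land in args and the keyword-only rounding_mode lands
    # in kwargs, e.g. ((self_node, other_node), {"rounding_mode": "trunc"}).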

    def get_fx_value(self, value: torch._C.Value):
        value_name = value.debugName()
        if value_name in self.name_to_node:
            input_node = self.name_to_node[value_name]
            return input_node
        elif value_name in self.attribute_map:
            attr_name = self.attribute_map[value_name]
            if attr_name in self.name_to_node:
                input_node = self.name_to_node[attr_name]
                return input_node
            else:
                raise ValueError(f"Value {attr_name} not found")
        elif value_name in self.constant_map:
            return self.constant_map[value_name]
        else:
            raise ValueError(f"Input {value_name} not found")

    def convert(self) -> torch.fx.GraphModule:
        self.convert_graph_inputs()

        for node in self.ts_graph.nodes():
            self.convert_node(node)

        self.convert_graph_outputs()

        gm = torch.fx.GraphModule(self.subgraphs, self.fx_graph)

        inplace_optimize_sym_size_div(gm)

        gm.graph.lint()
        return gm

    def convert_graph_inputs(self):
        for graph_input in self.ts_graph.inputs():
            name = graph_input.debugName()
            normalized_name = normalize_name(name)

            fx_node = self.fx_graph.placeholder(normalized_name)
            # TODO: set fx_node.meta["val"] to a FakeTensor
            self.name_to_node[name] = fx_node

            if name in self.param_names:
                self.input_specs.append(
                    InputSpec(
                        InputKind.PARAMETER,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                    )
                )
            elif name in self.buffer_names:
                self.input_specs.append(
                    InputSpec(
                        InputKind.BUFFER,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                        persistent=True,
                    )
                )
            else:
                self.input_specs.append(
                    InputSpec(
                        InputKind.USER_INPUT,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                    )
                )

    def convert_prim_Constant(self, node: torch._C.Node):
        name = node.output().debugName()

        value: Any = None
        if node.hasAttribute("value"):
            constant_kind = node.kindOf("value")
            if constant_kind == "i":
                value = node.i("value")
            elif constant_kind == "f":
                value = node.f("value")
            elif constant_kind == "s":
                value = node.s("value")
            elif constant_kind == "t":
                # Lift the tensor constant as a placeholder input.
                placeholder_name = f"constant_{name}"
                fx_node = self.fx_graph.placeholder(placeholder_name)
                self.name_to_node[name] = fx_node
                self.tensor_constants[placeholder_name] = node.t("value")

                self.input_specs.append(
                    InputSpec(
                        InputKind.CONSTANT_TENSOR,
                        arg=TensorArgument(name=placeholder_name),
                        target=placeholder_name,
                    )
                )

                value = fx_node
            elif constant_kind == "ival":
                value = node.ival("value")
            else:
                raise ValueError(f"Unsupported constant type: {node.kindOf('value')}")
        else:
            value = None

        self.constant_map[name] = value

    def convert_prim_device(self, node: torch._C.Node):
        input_type = node.input().type()
        if input_type.isSubtypeOf(torch._C.TensorType.get()):
            device = input_type.device()  # type: ignore[attr-defined]
            output_name = node.output().debugName()
            self.constant_map[output_name] = device
        else:
            raise ValueError(f"Unsupported JitType ({input_type}) when getting device")

    def convert_prim_dtype(self, node: torch._C.Node):
        dtype = node.input().type().dtype()
        output_name = node.output().debugName()
        self.constant_map[output_name] = dtype

    def convert_prim_GetAttr(self, node: torch._C.Node):
        def get_attr(name: str):
            if name in self.attribute_map:
                return self.attribute_map[name]
            else:
                raise ValueError(f"Attribute {name} not found")

        output_name = node.output().debugName()
        attr_name = node.s("name")
        input_name = node.input().debugName()

        root_attr_name = get_attr(input_name)
        self.attribute_map[output_name] = (
            f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
        )
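
    # For example (illustrative): for `self.linear.weight`, prim::CreateObject
    # maps `self` to "", the first prim::GetAttr produces "linear", and the
    # nested one produces "linear.weight" as the fully qualified attribute name.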

    def convert_call_function_op(self, node: torch._C.Node):
        target = get_op_overload(node)

        if target is torch.ops.aten.size.int:
            target = torch.ops.aten.sym_size.int

        args, kwargs = self.get_args_kwargs(node, target._schema)

        fx_node = self.fx_graph.call_function(target, args, kwargs)

        # TODO: convert sourceRange() into stack_trace
        # fx_node.meta["stack_trace"] = node.sourceRange()

        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_prim_TupleConstruct(self, node: torch._C.Node):
        self._convert_prim_iterator(node)

    def convert_prim_ListConstruct(self, node: torch._C.Node):
        self._convert_prim_iterator(node)

    def _convert_prim_iterator(self, node: torch._C.Node):
        output_list = []
        for inp in node.inputs():
            output_list.append(self.get_fx_value(inp))

        output_name = node.output().debugName()
        self.name_to_node[output_name] = output_list

    def convert_prim_DictConstruct(self, node: torch._C.Node):
        output_dict = {}
        k, v = None, None

        for i, inp in enumerate(node.inputs()):
            # We assume keys and values are interleaved in the DictConstruct
            # inputs: each key is immediately followed by its value.
            if i % 2 == 0:
                k = self.get_fx_value(inp)
            else:
                v = self.get_fx_value(inp)
                assert (
                    k is not None and v is not None
                ), "DictConstruct has an empty key value pair."
                output_dict[k] = v
                k, v = None, None

        assert (
            k is None and v is None
        ), "DictConstruct has an odd number of elements (violating our assumption)."

        output_name = node.output().debugName()
        self.name_to_node[output_name] = output_dict

    def convert_prim_ListUnpack(self, node: torch._C.Node):
        self._convert_prim_unpack_iterator(node)

    def convert_prim_TupleUnpack(self, node: torch._C.Node):
        self._convert_prim_unpack_iterator(node)

    def _convert_prim_unpack_iterator(self, node: torch._C.Node):
        # Single input and multiple outputs for unpacking.
        for i, outp in enumerate(node.outputs()):
            outp_name = outp.debugName()
            inp = self.get_fx_value(node.input())
            fx_node = self.fx_graph.call_function(operator.getitem, (inp, i))
            self.name_to_node[outp_name] = fx_node

    def convert_aten_Int(self, node: torch._C.Node):
        # Converts aten::Int as aten._to_copy + aten::_local_scalar_dense.
        target = torch.ops.aten._to_copy.default
        args = tuple(self.get_fx_value(input) for input in node.inputs())
        to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32})

        fx_node = self.fx_graph.call_function(
            torch.ops.aten._local_scalar_dense.default, (to_copy_node,)
        )

        # TODO: convert sourceRange() into stack_trace
        # fx_node.meta["stack_trace"] = node.sourceRange()

        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node
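
    # For example (illustrative): `int(t)` in TorchScript is lowered as
    #     copy = torch.ops.aten._to_copy.default(t, dtype=torch.int32)
    #     out = torch.ops.aten._local_scalar_dense.default(copy)
    # i.e. the scalar is extracted from the int32-cast tensor.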

    def convert_prim_NumToTensor(self, node: torch._C.Node):
        # Converts prim::NumToTensor as aten.scalar_tensor.
        target = torch.ops.aten.scalar_tensor
        args = tuple(self.get_fx_value(input) for input in node.inputs())

        fx_node = self.fx_graph.call_function(target, args)
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_prim_CreateObject(self, node: torch._C.Node):
        output_name = node.output().debugName()
        self.attribute_map[output_name] = ""

    def convert_aten__convolution(self, node: torch._C.Node):
        # Converts aten::_convolution as aten.convolution, since aten::_convolution
        # doesn't have a meta function.
        target = torch.ops.aten.convolution.default
        args, kwargs = self.get_args_kwargs(node, target._schema)

        fx_node = self.fx_graph.call_function(target, args, kwargs)

        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_aten_div(self, node: torch._C.Node):
        target = get_op_overload(node)
        schema = target._schema

        args, kwargs = self.get_args_kwargs(node, schema)

        # Converts aten::div.Tensor_mode(x, tensor_constant)
        # as aten.div.Scalar_mode(x, tensor_constant.item()).
        if schema.overload_name == "Tensor_mode":
            arg1_name = args[1].name
            if arg1_name in self.tensor_constants:
                tensor_constant = self.tensor_constants[arg1_name]
                if tensor_constant.numel() == 1:
                    updated_args = list(args)
                    updated_args[1] = self.tensor_constants[arg1_name].item()

                    fx_node = self.fx_graph.call_function(
                        torch.ops.aten.div.Scalar_mode,
                        tuple(updated_args),
                        kwargs,
                    )

                    # TODO: convert sourceRange() into stack_trace
                    # fx_node.meta["stack_trace"] = node.sourceRange()

                    output_name = node.output().debugName()
                    self.name_to_node[output_name] = fx_node
                    return

        self.convert_call_function_op(node)

    def convert_aten___getitem__(self, node: torch._C.Node):
        input_container, index = tuple(
            self.get_fx_value(input) for input in node.inputs()
        )
        fx_node = self.fx_graph.call_function(
            operator.getitem, (input_container, index)
        )
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_prim_If(self, node: torch._C.Node):
        inputs = list(node.inputs())
        assert len(inputs) == 1
        predicate = self.get_fx_value(inputs[0])

        # Get the union of the inputs used by the two blocks.
        arguments = set()
        for block in node.blocks():
            block_args = set()

            # TODO: block.inputs(); not sure what they're used for.
            for block_node in block.nodes():
                for block_node_in in block_node.inputs():
                    if block_node_in.debugName() in self.name_to_node:
                        block_args.add(block_node_in.debugName())

            arguments.update(block_args)

        arguments = list(arguments)

        # Convert blocks to subgraphs.
        subgraph_nodes = []
        for block in node.blocks():
            subgraph_converter = TS2FXGraphConverter(block, set(), set())
            subgraph_converter.constant_map = self.constant_map

            for block_arg in arguments:
                normalized_block_arg_name = normalize_name(block_arg)
                placeholder_node = subgraph_converter.fx_graph.placeholder(
                    normalized_block_arg_name
                )
                subgraph_converter.name_to_node[block_arg] = placeholder_node

            subgraph = subgraph_converter.convert()
            subgraph_name = self.add_subgraph(subgraph)
            subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))

        assert len(subgraph_nodes) == 2

        fx_block_args = [self.name_to_node[arg_name] for arg_name in arguments]
        args = (
            predicate,
            subgraph_nodes[0],
            subgraph_nodes[1],
            tuple(fx_block_args),
        )

        cond_node = self.fx_graph.call_function(torch.cond, args, {})

        output_name = node.output().debugName()
        self.name_to_node[output_name] = cond_node
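
    # For example (illustrative): a TorchScript `if pred: ... else: ...` becomes
    #     out = torch.cond(pred, true_subgraph, false_subgraph, (captured_args,))
    # where each branch is lowered into its own GraphModule whose placeholders
    # are the values it captures from the enclosing graph.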

    def convert_aten_Bool(self, node: torch._C.Node):
        self._convert_as_noop(node)

    def _convert_as_noop(self, node: torch._C.Node):
        # Converts the node as a no-op by mapping its output to its first input.
        target = get_op_overload(node)
        schema = target._schema

        args, kwargs = self.get_args_kwargs(node, schema)

        output_name = node.output().debugName()
        self.name_to_node[output_name] = args[0]

    def convert_profiler__record_function_enter_new(self, node: torch._C.Node):
        target = torch.ops.profiler._record_function_enter_new
        args = tuple(self.get_fx_value(input) for input in node.inputs())
        fx_node = self.fx_graph.call_function(target, args)
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_profiler__record_function_exit(self, node: torch._C.Node):
        # _record_function_exit has a side effect, so we keep it in the fx.Graph.
        # Currently, _record_function_enter_new and _record_function_exit are
        # discarded during `retrace_as_exported_program`.
        target = torch.ops.profiler._record_function_exit
        args = tuple(self.get_fx_value(input) for input in node.inputs())
        self.fx_graph.call_function(target, args)

    def _convert_standard_operators(self, node: torch._C.Node):
        target = kind_to_standard_operators[node.kind()]
        args = tuple(self.get_fx_value(input) for input in node.inputs())
        fx_node = self.fx_graph.call_function(target, args)
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_node(self, node: torch._C.Node):
        node_kind = node.kind()

        # Get the handler based on the namespace and operator name.
        # Fall back to the default handler in case we don't find a
        # matching converter for this node kind.
        handler_func_name = ir_name_to_func_name(node_kind)
        handler_func = getattr(self, handler_func_name, self.convert_call_function_op)
        handler_func(node)

    def convert_graph_outputs(self):
        args = []
        for graph_output in self.ts_graph.outputs():
            output_name = graph_output.debugName()
            if output_name in self.name_to_node:
                args.append(self.name_to_node[output_name])
                self.output_specs.append(
                    OutputSpec(
                        OutputKind.USER_OUTPUT,
                        arg=TensorArgument(name=output_name),
                        target=output_name,
                    )
                )
            elif output_name in self.constant_map:
                args.append(self.constant_map[output_name])
                self.output_specs.append(
                    OutputSpec(
                        OutputKind.USER_OUTPUT,
                        arg=ConstantArgument(
                            name=output_name, value=self.constant_map[output_name]
                        ),
                        target=output_name,
                    )
                )
            else:
                raise ValueError(f"Output {output_name} not found")

        self.fx_graph.output(
            args[0]
        )  # Get rid of the extra list wrapped around the final output.


class TS2EPConverter:
    # TorchScript model to ExportedProgram converter
    def __init__(
        self,
        ts_model,
        sample_args: Tuple[Any, ...],
        sample_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.ts_model = ts_model
        self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)

        self.sample_args = sample_args
        self.sample_kwargs = sample_kwargs

        self.param_names: Set[str] = {name for name, _ in ts_model.named_parameters()}
        self.buffer_names: Set[str] = {name for name, _ in ts_model.named_buffers()}

    def convert(self) -> ExportedProgram:
        graph_converter = TS2FXGraphConverter(
            self.ts_graph, self.param_names, self.buffer_names
        )
        gm = graph_converter.convert()
        ep = self.retrace_as_exported_program(gm, graph_converter.tensor_constants)
        return ep

    def retrace_as_exported_program(self, gm: torch.fx.GraphModule, tensor_constants):
        # TODO: adjust input orders to match GraphSignature convention
        inputs = [*self.sample_args, *self.params, *tensor_constants.values()]

        ep = torch.export._trace._export(
            gm,
            tuple(inputs),
            strict=False,
            pre_dispatch=True,
        )
        return ep
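

# Example usage (a sketch; `MyModule` and its sample inputs are hypothetical,
# and the converter only handles the subset of TorchScript covered above):
#
#     class MyModule(torch.nn.Module):
#         def forward(self, x):
#             return x + 1
#
#     ts_model = torch.jit.script(MyModule())
#     sample_args = (torch.randn(4, 3),)
#     ep: ExportedProgram = TS2EPConverter(ts_model, sample_args).convert()
#     print(ep)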