# extract_compiled_graph.py
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import dataclasses
  4. import itertools
  5. import os
  6. from typing import Any, Callable, Dict, List
  7. import torch
  8. import torch._lazy as lazy
  9. import torch._lazy.metrics as metrics
  10. from torch import fx
  11. from torch._lazy import computation, debug as lazy_debug
  12. from torch._lazy.tensor_factory_functions import tensor_factory_functions
  13. debug = os.environ.get("debug_extract_compiled_graph") is not None
  14. @dataclasses.dataclass
  15. class GraphInputMatcher:
  16. """
  17. The GraphInputMatcher class setup the graph inputs for future calls after lazy tracing.
  18. Specifically, those graph inputs corresponding to method parameters should be replaced with the
  19. arguments for the current call.
  20. tensor_id_to_arg_idx maps the tensor id to the parameter index.
  21. graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the
  22. TS/XLA graph inputs.
  23. """
  24. tensor_id_to_arg_idx: Dict[int, int]
  25. graph_input_tensor_ids: List[int]
  26. # there are 2 categories of graph_input_tensors.
  27. # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are
  28. # most likely const tensors and we can get its content from graph_input_tensors
  29. # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get
  30. # the tensor from method arguments
  31. graph_input_ivalues: List[Any]
  32. # get the real graph input tensors
  33. def __call__(self, args):
  34. real_input = []
  35. for tensor_id, traced_ivalue in zip(
  36. self.graph_input_tensor_ids, self.graph_input_ivalues
  37. ):
  38. arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None)
  39. if arg_idx is None:
  40. inp = traced_ivalue
  41. else:
  42. inp = args[arg_idx]
  43. real_input.append(inp)
  44. return real_input
  45. class ReturnValueHandler:
  46. r"""
  47. When ltc_sync_multi is called on multi tensors, the compiled graph
  48. will contain output only for unique tensors - if a tensor appears multiple
  49. times in the input to _ltc_sync_multi, only the first occurance matters.
  50. However from python level, we still expect multi tensors returned with duplciation
  51. even if the TS graph dedup the output. e.g. for method:
  52. def forward(self, a):
  53. return a, a
  54. the TS graph captured by LTC will return a single tensor, but Python method expects 2.
  55. This class dedup the lazy tensors first to get the index that will be used
  56. to duplicate the eager tensors later.
  57. """
  58. def __init__(self, lazy_out_list):
  59. self.index: List[List[int]] = []
  60. self.total_count = len(lazy_out_list)
  61. tensor_id_to_idx: Dict[int, int] = {}
  62. for dup_idx, lazy_tensor in enumerate(lazy_out_list):
  63. uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None)
  64. if uniq_idx is not None:
  65. self.index[uniq_idx].append(dup_idx)
  66. else:
  67. uniq_idx = len(self.index)
  68. self.index.append([dup_idx])
  69. tensor_id_to_idx[id(lazy_tensor)] = uniq_idx
  70. def duplicate_eager_tensors(self, eager_tensor_list):
  71. duplicated_list = [None] * self.total_count
  72. assert len(eager_tensor_list) == len(self.index)
  73. for uniq_idx, eager_tensor in enumerate(eager_tensor_list):
  74. for dup_idx in self.index[uniq_idx]:
  75. duplicated_list[dup_idx] = eager_tensor
  76. return duplicated_list
  77. def force_lazy_device(model: fx.GraphModule):
  78. """
  79. Factory methods in a Fx graph may create tensors for a specific eager devices.
  80. If we take no actions, those eager tensors will be mixed with lazy tensors and
  81. cause crash. This method overwrite those eager device to lazy device.
  82. """
  83. def tolazydevice(dev):
  84. if isinstance(dev, torch.device):
  85. return torch.device("lazy", index=dev.index)
  86. return dev
  87. def hasDeviceArg(args, kwargs):
  88. return any(
  89. isinstance(arg, torch.device)
  90. for arg in itertools.chain(args, kwargs.values())
  91. )
  92. for nd in model.graph.nodes:
  93. nd.args = tuple(tolazydevice(arg) for arg in nd.args)
  94. nd.kwargs = {k: tolazydevice(v) for k, v in nd.kwargs.items()}
  95. # For torchbench like yolov3, hf_Bart, dynamo generates Fx graph that return
  96. # eager tensors on the default device
  97. # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolove,
  98. # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart).
  99. # To force those tensors on the lazy device, we can not simply override
  100. # the device argument since there is no explicit device argument.
  101. # What we are doing here is, for the list of covered tensor factory methods
  102. # we add a lazy device argument explicity.
  103. #
  104. # TODO: This solution is no ideal since we may miss some factory methods. In future
  105. # when we support lazy mode, this method can be replaced by that.
  106. if nd.target in tensor_factory_functions and not hasDeviceArg(
  107. nd.args, nd.kwargs
  108. ):
  109. kwargs = dict(nd.kwargs) # nd.kwargs is immutable. make a mutable copy.
  110. kwargs["device"] = torch.device("lazy")
  111. nd.kwargs = kwargs
  112. model.recompile()
  113. def get_fallback_ops():
  114. fallback_ops = []
  115. for opname in metrics.counter_names():
  116. if "aten::" not in opname:
  117. continue
  118. val = int(metrics.counter_value(opname))
  119. if val > 0:
  120. fallback_ops.append(f"{opname}={val}")
  121. return fallback_ops
  122. def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable:
  123. """
  124. Optimize an eager model with LTC and returns a wrapper to execute the
  125. compiled graph directly without retracing. It depends on other mechanisms
  126. like TorchDynamo guards to guarantee the returned wrapper is only called
  127. when it's safe.
  128. """
  129. lazy_args = [arg.to(device="lazy") for arg in example_inputs]
  130. args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args]
  131. tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)}
  132. lazy_model = copy.deepcopy(model).to(device=torch.device("lazy"))
  133. force_lazy_device(lazy_model)
  134. # This line executes lazy tracing and enable us extracting compiled graph later
  135. metrics.reset()
  136. lazy_out = lazy_model(*lazy_args)
  137. fallback_ops = get_fallback_ops()
  138. metrics.reset()
  139. if len(fallback_ops) > 0:
  140. raise RuntimeError(
  141. f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}"
  142. )
  143. if not isinstance(lazy_out, (tuple, list)):
  144. lazy_out = (lazy_out,)
  145. args_and_out = tuple(lazy_args) + tuple(lazy_out)
  146. return_value_handler = ReturnValueHandler(args_and_out)
  147. if debug:
  148. print("Fx code:\n", model.code)
  149. print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text"))
  150. # TODO: this part is TS backend specific for now and will be generalized to
  151. # support XLA
  152. (
  153. graph_input_tensor_ids,
  154. graph_input_ivalues,
  155. ) = computation.get_tensors_ts_device_data_node(args_and_out)
  156. assert len(graph_input_tensor_ids) == len(graph_input_ivalues)
  157. graph_input_matcher = GraphInputMatcher(
  158. tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues
  159. )
  160. graph_hash = computation.get_graph_hash(args_and_out)
  161. if debug:
  162. print("graph_hash", graph_hash)
  163. print(f"args_tensor_ids {args_tensor_ids}")
  164. print("tensor ids from device data:", graph_input_tensor_ids)
  165. # sync the list of output tensors so the computation graph for these
  166. # tensors will be cached. Those computation graphs can be retrieved
  167. # by graph hash later.
  168. lazy.sync_multi(args_and_out, [])
  169. def optimized_mod(*args):
  170. if len(args_and_out) == 0:
  171. return ()
  172. graph_input = graph_input_matcher(args)
  173. res = return_value_handler.duplicate_eager_tensors(
  174. computation.run_cached_graph(graph_hash, graph_input)
  175. )
  176. assert len(res) == len(args_and_out)
  177. for i, arg in enumerate(args):
  178. # only copy those tensors that get inplace updated
  179. if arg is not res[i]:
  180. arg.copy_(res[i])
  181. # skip the args
  182. return res[len(args) :]
  183. return optimized_mod