cudagraphs.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # mypy: ignore-errors
  2. import functools
  3. import operator
  4. from collections import defaultdict
  5. from typing import Dict, List, Optional
  6. import torch
  7. from torch._dynamo import config
  8. from torch._dynamo.backends.common import aot_autograd
  9. from torch._dynamo.backends.debugging import boxed_nop
  10. from torch._inductor.cudagraph_utils import (
  11. BoxedDeviceIndex,
  12. check_multiple_devices_or_any_cpu_nodes,
  13. format_default_skip_message,
  14. get_mutation_stack_trace,
  15. get_placeholders,
  16. log_cudagraph_skip_and_bump_counter,
  17. )
  18. from torch._inductor.utils import (
  19. BoxedBool,
  20. count_tangents,
  21. get_first_incompatible_cudagraph_node,
  22. num_fw_fixed_arguments,
  23. output_node,
  24. )
  25. from torch.multiprocessing.reductions import StorageWeakRef
  26. from .registry import register_backend
  27. def find_input_mutations(g):
  28. def meta_fk(meta):
  29. return meta["val"] if "val" in meta else meta["fake_result"]
  30. inputs = defaultdict(set)
  31. input_idx = 0
  32. mutated_inputs = set()
  33. for n in g.nodes:
  34. if n.op == "placeholder":
  35. if isinstance(meta_fk(n.meta), torch.Tensor):
  36. inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
  37. input_idx += 1
  38. elif n.op == "call_function":
  39. if n.target is operator.getitem:
  40. continue
  41. schema = n.target._schema
  42. for i, arg in enumerate(schema.arguments):
  43. if i < len(n.args):
  44. argument = n.args[i]
  45. else:
  46. if arg.name not in n.kwargs:
  47. continue
  48. argument = n.kwargs[arg.name]
  49. mut_arg = False
  50. if arg.alias_info:
  51. if arg.alias_info.is_write:
  52. mut_arg = True
  53. if mut_arg:
  54. # TODO: not correct for args that contain tensors in a struct
  55. # like list
  56. mutated_inputs |= inputs[
  57. StorageWeakRef(meta_fk(argument.meta)._typed_storage())
  58. ]
  59. # TODO: error on unrecognized nodes
  60. return mutated_inputs
  61. def get_device_node_mapping(gm: torch.fx.GraphModule):
  62. device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
  63. for n in gm.graph.nodes:
  64. t = n.meta.get("val", None)
  65. if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
  66. device_node_mapping[t.device] = n
  67. return device_node_mapping
  68. def check_for_mutation_ignore_cuda_graph_managed_tensor(
  69. aot_model: torch.fx.GraphModule, num_fixed
  70. ) -> Optional[str]:
  71. mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
  72. if not mutation_indices:
  73. return None
  74. placeholders = [node for node in aot_model.graph.nodes if node.op == "placeholder"]
  75. return get_mutation_stack_trace(placeholders, mutation_indices)
  76. def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
  77. if not config.cudagraph_backend_support_input_mutation:
  78. if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
  79. aot_model, num_fixed
  80. ):
  81. return mut_skip
  82. if skip := check_multiple_devices_or_any_cpu_nodes(
  83. get_device_node_mapping(aot_model)
  84. ):
  85. return skip
  86. if node := get_first_incompatible_cudagraph_node(aot_model):
  87. return format_default_skip_message(f"incompatible op ({node.name})")
  88. return None
  89. def get_device_index(gm) -> int:
  90. device = next(iter(get_device_node_mapping(gm)))
  91. assert device.type == "cuda"
  92. return device.index
  93. def get_stack_traces(gm) -> List[Optional[str]]:
  94. output = output_node(gm)
  95. assert len(output.args) == 1
  96. return [
  97. (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
  98. for arg in output.args[0]
  99. ]
  100. def cudagraphs(dynamo_model, dynamo_inputs):
  101. from torch._inductor.cudagraph_trees import cudagraphify_impl
  102. do_cudagraphs = BoxedBool(True)
  103. boxed_device_index = BoxedDeviceIndex(None)
  104. def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
  105. interp = boxed_nop(aot_model, aot_inputs)
  106. fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
  107. if skip_msg := check_for_skip(aot_model, fixed):
  108. BoxedBool.disable(do_cudagraphs)
  109. log_cudagraph_skip_and_bump_counter(
  110. f"skipping cudagraphs due to {skip_msg}"
  111. )
  112. return interp
  113. boxed_device_index.set(get_device_index(aot_model))
  114. out = cudagraphify_impl(
  115. interp,
  116. aot_inputs,
  117. range(fixed),
  118. device_index=boxed_device_index.value,
  119. is_backward=False,
  120. is_inference=False,
  121. stack_traces=get_stack_traces(aot_model),
  122. placeholders=get_placeholders(aot_model.graph),
  123. mutated_input_idxs=find_input_mutations(aot_model.graph),
  124. )
  125. out._boxed_call = True
  126. return out
  127. def backward_cudagraphs(aot_model, aot_inputs):
  128. interp = boxed_nop(aot_model, aot_inputs)
  129. if not do_cudagraphs:
  130. return aot_model
  131. fixed = count_tangents(aot_model)
  132. if skip_msg := check_for_skip(aot_model, fixed):
  133. log_cudagraph_skip_and_bump_counter(
  134. "skipping cudagraphs due to %s", skip_msg
  135. )
  136. # See [Backward Generation Handling]
  137. manager = torch._inductor.cudagraph_trees.get_manager(
  138. boxed_device_index.value, create_if_none_exists=False
  139. )
  140. assert manager is not None
  141. def fn(inputs):
  142. manager.set_to_running_backward()
  143. return aot_model(inputs)
  144. fn._boxed_call = True
  145. return fn
  146. out = cudagraphify_impl(
  147. interp,
  148. aot_inputs,
  149. range(fixed),
  150. device_index=get_device_index(aot_model),
  151. is_backward=True,
  152. is_inference=False,
  153. stack_traces=get_stack_traces(aot_model),
  154. placeholders=get_placeholders(aot_model.graph),
  155. mutated_input_idxs=find_input_mutations(aot_model.graph),
  156. )
  157. out._boxed_call = True
  158. return out
  159. aot_cudagraphs = aot_autograd(
  160. fw_compiler=forward_cudagraphs,
  161. bw_compiler=backward_cudagraphs,
  162. inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
  163. keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
  164. )
  165. return aot_cudagraphs(dynamo_model, dynamo_inputs)
class CudagraphsBackend:
    """Callable backend object registered with dynamo as "cudagraphs"."""

    # Name dynamo uses to identify this compiler.
    compiler_name = "cudagraphs"

    @staticmethod
    def reset():
        """Drop all recorded CUDA graph trees (clears captured graph state)."""
        # Imported lazily, matching the lazy cudagraph_trees imports elsewhere
        # in this module.
        from torch._inductor.cudagraph_trees import reset_cudagraph_trees

        reset_cudagraph_trees()

    @staticmethod
    def __call__(model, inputs):
        """Compile ``model`` via the module-level ``cudagraphs`` backend."""
        return cudagraphs(model, inputs)
# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
# for debugging and can serve as a perf baseline.
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
  178. def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
  179. """This isn't registered as a backend, but is used in some benchmarks"""
  180. assert isinstance(inputs, (list, tuple))
  181. if copy_inputs:
  182. static_inputs = [torch.zeros_like(x) for x in inputs]
  183. else:
  184. static_inputs = list(inputs)
  185. # warmup
  186. torch.cuda.synchronize()
  187. stream = torch.cuda.Stream()
  188. stream.wait_stream(torch.cuda.current_stream())
  189. with torch.cuda.stream(stream):
  190. model(*inputs)
  191. stream.synchronize()
  192. torch.cuda.current_stream().wait_stream(stream)
  193. torch.cuda.synchronize()
  194. # record
  195. graph = torch.cuda.CUDAGraph()
  196. with torch.cuda.graph(graph, stream=stream):
  197. static_outputs = model(*static_inputs)
  198. if not isinstance(static_outputs, (list, tuple)):
  199. static_outputs = (static_outputs,)
  200. def run(*new_inputs):
  201. assert len(static_inputs) == len(new_inputs)
  202. if copy_inputs:
  203. for dst, src in zip(static_inputs, new_inputs):
  204. dst.copy_(src)
  205. graph.replay()
  206. if copy_outputs:
  207. return [x.clone() for x in static_outputs]
  208. else:
  209. return static_outputs
  210. return run