# mypy: allow-untyped-defs
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch._dynamo.utils import counters

perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")


@dataclasses.dataclass(frozen=True)
class FunctionID:
    "Unique counter of a function wrapped in cudagraphify_impl"
    id: int


@dataclasses.dataclass(frozen=True)
class WrappedFunction:
    """
    Represents a function that you want to record for CUDA graph replay,
    with a little more metadata so we can identify if we have an applicable
    CUDA graph in our CUDA graph tree for it.
    """

    model: Callable[..., Any]
    static_input_idxs: List[int]
    id: FunctionID
    constants: Tuple[torch.Tensor, ...]
    placeholders: List[torch.fx.Node]
    mutated_input_idxs: List[int]


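# Illustrative sketch (not part of the upstream source): roughly how a
# WrappedFunction might be assembled for a compiled FX graph. The names
# `compiled_fn` and `gm` are hypothetical stand-ins for the compiled callable
# and its GraphModule; the index lists are placeholder values.
#
#   wrapped = WrappedFunction(
#       model=compiled_fn,
#       static_input_idxs=[0, 1],  # e.g. parameter/buffer positions
#       id=FunctionID(0),
#       constants=tuple(gm.state_dict().values()),
#       placeholders=get_placeholders(gm.graph),
#       mutated_input_idxs=[2],
#   )

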
def get_placeholders(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    return [node for node in graph.nodes if node.op == "placeholder"]


def get_mutating_use_stack_trace(placeholder_node: torch.fx.Node) -> Optional[str]:
    # reinplaced uses might have a single, non-copy_ use
    if len(placeholder_node.users) == 1:
        return next(iter(placeholder_node.users)).meta.get("stack_trace", None)

    for use in placeholder_node.users:
        if use.target == torch.ops.aten.copy_.default:
            if stack_trace := use.meta.get("stack_trace", None):
                return stack_trace

    return None


def format_default_skip_message(reason: str) -> str:
    return f"skipping cudagraphs due to {reason}"


def get_mutation_stack_trace(
    placeholders: List[torch.fx.Node], mutation_indices: List[int]
) -> str:
    stack_trace: Optional[str] = ""

    for idx in mutation_indices:
        placeholder = placeholders[idx]
        if stack_trace := get_mutating_use_stack_trace(placeholder):
            break

    msg = format_default_skip_message(
        f"mutated inputs ({len(mutation_indices)} instances)"
    )
    if stack_trace:
        return f"{msg}. Found from : \n {stack_trace}"

    return msg


def check_for_mutation(
    func: WrappedFunction,
    inputs: List[torch.Tensor],
    is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool],
) -> Optional[str]:
    """
    Return a cudagraph skip message if the function mutates inputs that are
    neither static inputs nor cudagraph-recorded tensors, otherwise None.
    """
    # doesn't work for non-trees because the warmup run would apply mutation twice
    if torch._inductor.config.triton.cudagraph_trees:
        # checking if mutation is only on parameters/static inputs
        mutation_indices = [
            idx
            for idx in func.mutated_input_idxs
            if not (
                idx in func.static_input_idxs
                or is_cuda_graph_recorded_tensor(inputs[idx])
            )
        ]
    else:
        mutation_indices = func.mutated_input_idxs

    return (
        get_mutation_stack_trace(func.placeholders, mutation_indices)
        if mutation_indices
        else None
    )


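# Illustrative sketch (not part of the upstream source): how a caller might use
# check_for_mutation to decide whether to skip cudagraph recording. `wrapped`,
# `inputs`, and `is_recorded` are hypothetical stand-ins for the WrappedFunction,
# its runtime inputs, and a predicate for cudagraph-managed tensors.
#
#   if (skip_msg := check_for_mutation(wrapped, inputs, is_recorded)) is not None:
#       log_cudagraph_skip_and_bump_counter(skip_msg)
#       # fall back to running wrapped.model eagerly instead of recording

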
def get_use_stack_trace(node) -> Optional[str]:
    for use in node.users:
        if stack_trace := use.meta.get("stack_trace", None):
            return stack_trace
    return None


def check_multiple_devices_or_any_cpu_nodes(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
) -> Optional[str]:
    """
    Return a cudagraph skip message if the graph contains any CPU node or spans
    anything other than a single CUDA device, otherwise None.
    """
    if cpu_node := device_node_mapping.get(torch.device("cpu")):
        msg = f"cpu device ({cpu_node.name})"
        if stack_trace := get_use_stack_trace(cpu_node):
            return format_default_skip_message(f"{msg}. Found from : \n {stack_trace}")

        return format_default_skip_message(msg)

    if (
        len(device_node_mapping) == 1
        and next(iter(device_node_mapping.keys())).type == "cuda"
    ):
        return None

    keys_repr = (repr(key) for key in device_node_mapping.keys())
    return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}")


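# Illustrative sketch (not part of the upstream source): the expected shape of
# `device_node_mapping` and the resulting messages. Node variables are
# hypothetical.
#
#   {torch.device("cpu"): cpu_node}
#       -> "skipping cudagraphs due to cpu device (<node name>)"
#   {torch.device("cuda", 0): node_a, torch.device("cuda", 1): node_b}
#       -> "skipping cudagraphs due to multiple devices: ..."
#   {torch.device("cuda", 0): node_a}
#       -> None

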
def check_lowering_disable_cudagraph(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
):
    return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)


def log_cudagraph_skip_and_bump_counter(msg):
    perf_hint_log.warning(msg)
    counters["inductor"]["cudagraph_skips"] += 1


@dataclasses.dataclass
class BoxedDeviceIndex:
    value: Optional[int]

    def set(self, device_idx: Optional[int]):
        assert device_idx is None or isinstance(device_idx, int)
        self.value = device_idx


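# Illustrative sketch (not part of the upstream source): BoxedDeviceIndex acts
# as a mutable cell, so the device index can be shared by reference and filled
# in later once the device is known.
#
#   boxed = BoxedDeviceIndex(None)
#   ...
#   boxed.set(0)
#   assert boxed.value == 0

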
def check_for_mutation_ignore_cuda_graph_managed_tensor(
    gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: List[int]
) -> Optional[str]:
    default_msg = format_default_skip_message("mutated inputs")

    # doesn't work for non-trees because the warmup run would apply mutation twice
    if torch._inductor.config.triton.cudagraph_trees:
        unique_idxs = set(static_input_idxs)
        # checking if mutation is only on parameters/static inputs
        mutation_indices = [
            idx for idx in compiled_graph.mutated_input_idxs if idx not in unique_idxs
        ]
        has_mutation = len(mutation_indices) != 0
        if not has_mutation:
            return None
        placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
        return get_mutation_stack_trace(placeholders, mutation_indices)

    else:
        has_mutation = len(compiled_graph.mutated_inputs) != 0
        return None if not has_mutation else default_msg


def get_placeholder_stack_trace(placeholder: torch.fx.Node) -> Optional[str]:
    """
    Gets the first non-empty stack trace of a placeholder or its users.
    """
    if placeholder.stack_trace:
        return placeholder.stack_trace

    for user in placeholder.users:
        if user.stack_trace:
            return user.stack_trace

    return None