# comms.py
# mypy: allow-untyped-defs
# pyre-strict
from typing import List

import torch

from . import config, ir, scheduler
from .dependencies import WeakDep
from .utils import is_collective, is_wait, tuple_sorted

overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")


def sink_waits(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Greedily moves waits as late as possible (i.e. until we reach a use). Optimal in terms of
    communication overlap.
    """
    new_order = []
    cur_waits = set()
    for snode in snodes:
        if is_wait(snode.node):
            cur_waits.add(snode)
        else:
            for wait in tuple_sorted(cur_waits):
                if snode in wait.node_users:
                    new_order.append(wait)
                    cur_waits.remove(wait)
            new_order.append(snode)
    new_order.extend(tuple_sorted(cur_waits))
    return new_order
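
# A sketch of what `sink_waits` does on a hypothetical schedule (names are
# illustrative only, not real scheduler-node APIs):
#
#     [allreduce, wait(allreduce), mm1, mm2(uses wait)]
#         -> [allreduce, mm1, wait(allreduce), mm2(uses wait)]
#
# `mm1` does not consume the wait's output, so the wait sinks past it and
# `mm1` overlaps with the in-flight collective.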


def raise_comms(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Greedily moves comms as early as possible (i.e. until we reach an input).
    Optimal in terms of communication overlap.

    TODO: We might want to adjust this in the future to account for memory limitations.
    e.g. when we are compiling FSDP, this heuristic will cause the all-gathers to be prefetched as soon as possible,
    which is the beginning of the forward pass. We'll have to either do a special pass for FSDP,
    or we'll want to redo this pass with memory considerations so we handle the FSDP case in a general way.
    """
    new_order_reversed: List["scheduler.BaseSchedulerNode"] = []
    cur_comms: List["scheduler.BaseSchedulerNode"] = []
    for snode in reversed(snodes):
        if is_collective(snode.node):
            cur_comms.append(snode)
        else:
            for comm in cur_comms:
                assert len(comm.inverse_users) > 0
            while len(cur_comms) > 0 and any(
                snode in comm.inverse_users for comm in cur_comms
            ):
                comm = cur_comms.pop(0)
                new_order_reversed.append(comm)
            new_order_reversed.append(snode)
    assert len(cur_comms) <= 1
    new_order_reversed.extend(tuple_sorted(cur_comms))
    return new_order_reversed[::-1]
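
# Mirror image of `sink_waits`, on a hypothetical schedule (illustrative
# names only):
#
#     [mm1, mm2, allgather(inputs: mm1), wait(allgather)]
#         -> [mm1, allgather(inputs: mm1), mm2, wait(allgather)]
#
# The collective is hoisted to just after its last input (`mm1`) is
# produced, so `mm2` now runs while the all-gather is in flight.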


def get_ancestors(node):
    ancestors = set()
    cur_nodes = [node]
    while len(cur_nodes) > 0:
        new_nodes = []
        for node in cur_nodes:
            for inp in node.inverse_users:
                if inp not in ancestors:
                    ancestors.add(inp)
                    new_nodes.append(inp)
        cur_nodes = new_nodes
    return ancestors


def get_descendants(node):
    descendants = set()
    cur_nodes = [node]
    while len(cur_nodes) > 0:
        new_nodes = []
        for node in cur_nodes:
            for user in node.node_users:
                if user not in descendants:
                    descendants.add(user)
                    new_nodes.append(user)
        cur_nodes = new_nodes
    return descendants
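
# Both helpers are plain BFS over `inverse_users` / `node_users`. A
# self-contained analogue on a dict-based toy DAG (not the scheduler's
# node type):
#
#     users = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
#
#     def toy_descendants(start):
#         seen, frontier = set(), [start]
#         while frontier:
#             nxt = []
#             for n in frontier:
#                 for u in users[n]:
#                     if u not in seen:
#                         seen.add(u)
#                         nxt.append(u)
#             frontier = nxt
#         return seen  # toy_descendants("a") == {"b", "c", "d"}
#
# As in `get_descendants`, the start node itself is not included.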


def decide_global_ordering_of_comms(nodes: List["scheduler.BaseSchedulerNode"]):
    """
    Decide global ordering of comms, by just enforcing the ordering that's in the input graph
    (might not be the same ordering as the eager mode program).
    TODO: Come up with a better approach
    """
    comm_nodes = [n for n in nodes if is_collective(n.node)]
    for i in range(1, len(comm_nodes)):
        # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
        comm_nodes[i].add_fake_dep(WeakDep(comm_nodes[i - 1].get_name()))
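
# E.g. with collectives c1, c2, c3 appearing in that order in the input
# graph, this adds fake deps c1 <- c2 and c2 <- c3, forcing every later
# pass to keep the c1, c2, c3 issue order (a `WeakDep` constrains
# scheduling without adding a real data dependency).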


def assert_no_comm_nodes(snodes: List["scheduler.BaseSchedulerNode"]) -> None:
    assert not any(is_collective(snode.node) for snode in snodes)


def estimate_op_runtime(snode: "scheduler.BaseSchedulerNode") -> float:
    """
    Returns estimated op runtime in nanoseconds (ns)
    """
    if config.estimate_op_runtime == "default":
        runtime = snode.get_estimated_runtime()
    else:
        assert callable(config.estimate_op_runtime)
        runtime = config.estimate_op_runtime(snode)
    return runtime
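
# A minimal sketch of plugging in a custom estimator (hypothetical numbers;
# per the branch above, any callable mapping a scheduler node to
# nanoseconds works):
#
#     def my_estimator(snode) -> float:
#         # pretend collectives take 1 ms and everything else 10 us
#         return 1e6 if is_collective(snode.node) else 1e4
#
#     torch._inductor.config.estimate_op_runtime = my_estimator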


def reorder_compute_for_overlap(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Decides a global ordering of all compute and communication nodes,
    assuming that we already have a global ordering of communication nodes.

    Overall scheduling procedure is:
        Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
            that are required for comm N + 1 but do not depend on comm N, to run at the same time as comm N.
        Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
            Otherwise, we now need to look elsewhere to find compute that overlaps with comm N.
            We prioritize compute nodes that are needed sooner.
        Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
        Step 4: We schedule comm N + 1.
    Repeat this for subsequent comm nodes.
    """
    final_order = []

    comm_nodes = []
    for snode in snodes:
        if is_collective(snode.node):
            comm_nodes.append(snode)
    if len(comm_nodes) == 0:
        # if there are no comm nodes, return the current order
        return snodes

    comm_ancestors = {node: get_ancestors(node) for node in comm_nodes}
    comm_descendants = {node: get_descendants(node) for node in comm_nodes}

    indeg = dict.fromkeys(snodes, 0)
    for snode in snodes:
        for user in snode.node_users:
            if user in indeg:
                indeg[user] += 1
    ready_to_schedule_nodes = {node for node in snodes if indeg[node] == 0}
    unscheduled_nodes = set(snodes)

    def schedule_node(snode):
        """
        Schedule a single node.
        """
        assert snode in unscheduled_nodes
        assert snode in ready_to_schedule_nodes
        ready_to_schedule_nodes.remove(snode)
        unscheduled_nodes.remove(snode)
        final_order.append(snode)
        for user in tuple_sorted(snode.node_users):
            if user in indeg:
                indeg[user] -= 1
                if indeg[user] == 0:
                    ready_to_schedule_nodes.add(user)

    def schedule_nodes(snodes):
        """
        Schedules all nodes in `snodes` in an arbitrary topologically valid order.
        """
        all_nodes = set(snodes)
        assert all(node in unscheduled_nodes for node in all_nodes)
        while len(all_nodes) > 0:
            # NOTE: since the model graph is always a DAG and has no circular dependencies,
            # there should be at least one node that is a "free node" (i.e. indeg == 0),
            # hence an infinite loop is not possible. But we check here just to be safe.
            progress = False
            for node in tuple_sorted(all_nodes):
                if node in ready_to_schedule_nodes:
                    schedule_node(node)
                    all_nodes.remove(node)
                    progress = True
            if not progress:
                raise AssertionError(
                    "Unable to find a free node (indeg == 0). This is an impossible state to reach. "
                    "Please report a bug to PyTorch."
                )
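
    # Together, `schedule_node` / `schedule_nodes` implement Kahn's topological
    # sort: `indeg` counts unscheduled predecessors, and a node becomes ready
    # exactly when that count drops to zero.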

    # First, schedule all compute nodes that are required by the first comm node,
    # as well as the first comm node itself.
    assert len(comm_nodes) > 0
    schedule_nodes(
        list(comm_ancestors[comm_nodes[0]]) + [comm_nodes[0]],
    )

    rolled_over_compute_cost = 0
    for idx in range(1, len(comm_nodes)):
        # Step 1: Given that we've currently scheduled comm `idx-1`, we now schedule
        # all compute nodes that are required for comm `idx` but do not depend on comm `idx-1`,
        # to run at the same time as comm `idx-1`.
        needed_by_next_comm_and_ready_compute_nodes = unscheduled_nodes & (
            comm_ancestors[comm_nodes[idx]] - comm_descendants[comm_nodes[idx - 1]]
        )
        assert_no_comm_nodes(needed_by_next_comm_and_ready_compute_nodes)

        total_compute_runtime_cost = rolled_over_compute_cost + sum(
            estimate_op_runtime(node)
            for node in needed_by_next_comm_and_ready_compute_nodes
        )
        prev_comm_runtime_cost = estimate_op_runtime(comm_nodes[idx - 1])
        schedule_nodes(tuple_sorted(needed_by_next_comm_and_ready_compute_nodes))

        # Step 2: If all those compute nodes are sufficient to overlap comm `idx-1`, we're done.
        # Otherwise, we now need to look elsewhere to find compute that overlaps with comm `idx-1`.
        # We prioritize compute nodes that are needed sooner.
        step1_runtime_cost = total_compute_runtime_cost
        if step1_runtime_cost < prev_comm_runtime_cost:
            # Find all ready-to-schedule compute nodes that do not depend on comm `idx-1`.
            ready_to_schedule_compute_nodes = tuple_sorted(
                ready_to_schedule_nodes - comm_descendants[comm_nodes[idx - 1]]
            )
            assert_no_comm_nodes(ready_to_schedule_compute_nodes)

            def earliest_comm_descendant(node):
                for comm_idx in range(len(comm_nodes)):
                    if node in comm_ancestors[comm_nodes[comm_idx]]:
                        return comm_idx
                return len(comm_nodes)

            # Prioritize compute nodes that are needed sooner.
            ready_to_schedule_compute_nodes = sorted(
                ready_to_schedule_compute_nodes, key=earliest_comm_descendant
            )

            for snode in ready_to_schedule_compute_nodes:
                if total_compute_runtime_cost >= prev_comm_runtime_cost:
                    # If the accumulated compute runtime cost is greater than comm `idx-1`'s runtime cost,
                    # it means we have maximized overlap for comm `idx-1`, and hence we stop looking
                    # for more compute to schedule.
                    break
                compute_runtime_cost = estimate_op_runtime(snode)
                # If we're not able to leverage more than half of this
                # node's compute to overlap, we skip it.
                # TODO: Smarter heuristics here
                if (
                    prev_comm_runtime_cost - total_compute_runtime_cost
                ) <= compute_runtime_cost / 2:
                    continue
                schedule_node(snode)
                total_compute_runtime_cost += compute_runtime_cost

        rollable_compute_cost = total_compute_runtime_cost - step1_runtime_cost

        # Step 3: We schedule the compute nodes dependent on comm `idx-1` and required for comm `idx`.
        needed_by_next_comm_nodes = unscheduled_nodes & comm_ancestors[comm_nodes[idx]]
        schedule_nodes(list(needed_by_next_comm_nodes))

        # Step 4: We schedule comm `idx`.
        schedule_nodes([comm_nodes[idx]])

        is_prev_comm_blocking_next_comm = len(needed_by_next_comm_nodes) > 0

        # The idea here is that if there are no compute nodes from Step 3
        # (i.e. if prev comm is not blocking next comm), we can roll over the compute nodes
        # in Step 2 to overlap with the next comm, since they're not required to finish
        # before the next comm starts.
        if is_prev_comm_blocking_next_comm:
            rolled_over_compute_cost = 0
        else:
            rolled_over_compute_cost = rollable_compute_cost  # type: ignore[assignment]

    schedule_nodes(unscheduled_nodes)
    return final_order
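
# Worked example of the Step 2 accounting above (made-up numbers, in ns):
# say comm `idx-1` is estimated at 1_000_000 and Step 1 only found 400_000
# of compute. Walking other ready compute nodes: a 500_000-ns node is
# scheduled, since the exposed 600_000 ns exceeds half its cost (250_000);
# a subsequent 1_500_000-ns node is skipped, since only 100_000 ns of comm
# remains exposed, less than half of that node's cost (750_000).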


def node_summary(snode):
    detail = ""
    if isinstance(snode.node, ir.ExternKernelOut):
        detail = f" ({snode.node.python_kernel_name})"
    out_tensor_info = ""
    if (
        hasattr(snode.node, "layout")
        and hasattr(snode.node.layout, "size")
        and hasattr(snode.node.layout, "stride")
    ):
        out_tensor_info = (
            f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})"
        )
    node_name = ""
    if hasattr(snode.node, "name"):
        node_name = snode.node.name
    return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})"


def visualize_overlap(order):
    total_est_runtime: float = 0.0
    cur_comm_node = None
    for snode in order:
        if cur_comm_node is None:
            if is_collective(snode.node):
                total_est_runtime += estimate_op_runtime(snode)
                cur_comm_node = snode.node
            elif is_wait(snode.node):
                raise AssertionError(
                    "Wait is not expected when there is no collective running"
                )
            else:  # exposed compute op
                total_est_runtime += estimate_op_runtime(snode)
            overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
        else:  # cur_comm_node is not None
            if is_collective(snode.node):
                raise AssertionError(
                    "Found two collectives running at the same time. "
                    "`visualize_overlap` needs to be updated to handle this case"
                )
            elif is_wait(snode.node):  # end of this comm op
                overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
                cur_comm_node = None
            else:  # overlapped compute op
                overlap_log.debug(f"| {node_summary(snode)}")  # noqa: G004
    overlap_log.debug(
        f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}"  # noqa: G004
    )
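
# On a real schedule, the debug log drawn by this function might look like
# (hypothetical node names; `|`-prefixed lines are compute overlapped with
# the collective that is currently in flight):
#
#     _CollectiveKernel (buf0)
#     | ExternKernelOut (extern_kernels.mm) (size=..., stride=...) (buf1)
#     _WaitKernel (buf2)
#     ComputedBuffer (buf3)
#     Est. runtime (ms): 1.234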


def reorder_compute_and_comm_for_overlap(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    order = snodes
    for p in config.reorder_for_compute_comm_overlap_passes:
        if isinstance(p, str) and p in globals():
            p = globals()[p]  # it is a builtin pass
        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap before reordering pass {p} ===="  # noqa: G004
            )
            try:
                visualize_overlap(order)
            except Exception as e:
                overlap_log.debug(str(e))
        order = p(order)  # type: ignore[operator]
        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap after reordering pass {p} ===="  # noqa: G004
            )
            try:
                visualize_overlap(order)
            except Exception as e:
                overlap_log.debug(str(e))
    return order
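
# Typical wiring (a sketch; the exact config surface lives in
# torch/_inductor/config.py, and the pass names below are the builtin
# passes defined in this file):
#
#     import torch._inductor.config as inductor_config
#
#     inductor_config.reorder_for_compute_comm_overlap = True
#     inductor_config.reorder_for_compute_comm_overlap_passes = [
#         "reorder_compute_for_overlap",
#         "sink_waits",
#         "raise_comms",
#     ]
#
# Entries may be pass names from this module (resolved via `globals()`
# above) or custom callables taking and returning a list of scheduler nodes.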