# comm_analysis.py

import functools
import math
from enum import IntEnum

import sympy

import torch

from . import ir
from .utils import get_dtype_size, sympy_product
from .virtualized import V


class NCCL_COLL(IntEnum):
    ALL_REDUCE = 0
    ALL_GATHER = 1
    REDUCE_SCATTER = 2


class NVIDIA_GPU_TYPE(IntEnum):
    VOLTA = 0
    AMPERE = 1
    HOPPER = 2


@functools.lru_cache
def get_gpu_type() -> NVIDIA_GPU_TYPE:
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
    if "V100" in gpu_info:
        return NVIDIA_GPU_TYPE.VOLTA
    elif "A100" in gpu_info:
        return NVIDIA_GPU_TYPE.AMPERE
    elif "H100" in gpu_info:
        return NVIDIA_GPU_TYPE.HOPPER
    else:
        # for other gpu types, assume Ampere
        return NVIDIA_GPU_TYPE.AMPERE


def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
    if not isinstance(node, ir._CollectiveKernel):
        raise ValueError(f"node is not a collective kernel: {node}")

    kernel_name = node.python_kernel_name
    assert kernel_name is not None
    if "all_reduce" in kernel_name:
        return NCCL_COLL.ALL_REDUCE
    elif "all_gather" in kernel_name:
        return NCCL_COLL.ALL_GATHER
    elif "reduce_scatter" in kernel_name:
        return NCCL_COLL.REDUCE_SCATTER
    else:
        raise ValueError(f"Unsupported collective kernel: {kernel_name}")


def get_collective_input_size_bytes(node: ir.IRNode) -> int:
    sz_bytes = 0
    for inp in node.inputs:  # type: ignore[attr-defined]
        numel = sympy_product(inp.layout.size)
        if isinstance(numel, sympy.Integer):
            # For ease of testing
            numel = int(numel)
        else:
            numel = V.graph.sizevars.size_hint(numel, fallback=0)
        sz_bytes += numel * get_dtype_size(inp.layout.dtype)
    return sz_bytes
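
# Illustrative arithmetic for the helper above, with example values: a single bfloat16
# input of shape [1024, 1024] contributes 1024 * 1024 elements * 2 bytes/element
# = 2,097,152 bytes, and multiple inputs simply accumulate into sz_bytes.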


def get_collective_group_size(node: ir.IRNode) -> int:
    if type(node) == ir._CollectiveKernel:
        from torch.distributed.distributed_c10d import _get_group_size_by_name

        return _get_group_size_by_name(node.constant_args[-1])
    else:
        raise TypeError(f"Unsupported collective type: {node}")


####################################################################################################################
# The following code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
####################################################################################################################


class NCCL_HW(IntEnum):
    NVLINK = 0
    PCI = 1
    NET = 2


class NCCL_ALGO(IntEnum):
    TREE = 0
    RING = 1


class NCCL_PROTO(IntEnum):
    # The ordering and enum values here matches original in
    # https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28
    # For difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990
    LL = 0  # Low-latency
    # LL128 = 1  # Low-latency 128-byte
    # SIMPLE = 2


# Latencies in us
# len(NCCL_ALGO) x len(NCCL_PROTO)
# NOTE: use array instead of tensor to prevent incompatibility with fake mode
baseLat = [
    # Tree
    [
        6.8,  # LL
    ],
    # Ring
    [
        6.6,  # LL
    ],
]

# Latencies in us
# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
hwLat = [
    # NVLINK
    [
        [0.6],  # Tree (LL)
        [0.6],  # Ring (LL)
    ],
    # PCI
    [
        [1.0],  # Tree (LL)
        [1.0],  # Ring (LL)
    ],
    # NET
    [
        [5.0],  # Tree (LL)
        [2.7],  # Ring (LL)
    ],
]

# LL128 max BW per channel
llMaxBws = [
    # Volta-N1/Intel-N2/Intel-N4
    [
        39.0,
        39.0,
        20.4,
    ],
    # Ampere-N1/AMD-N2/AMD-N4
    [
        87.7,
        22.5,  # avg of ring & tree
        19.0,
    ],
    # Hopper-N1/AMD-N2/AMD-N4
    [
        87.7,
        22.5,  # avg of ring & tree
        19.0,
    ],
]
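
# Indexing example (illustrative values only): estimate_nccl_collective_runtime below uses
# the GPU generation as the first index for a single-node job and falls back to row 0 for
# multi-node jobs, so a single Ampere node reads llMaxBws[NVIDIA_GPU_TYPE.AMPERE][0] == 87.7
# GB/s, while a two-node job reads llMaxBws[0][1] == 39.0 GB/s.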


def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).

    The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
    We aim to estimate the runtime as accurately as possible.

    Assumptions:
    - only ring algorithm (NCCL_ALGO_RING) is used
    - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used
    - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    - collective is one of: allreduce, reducescatter, allgather
    """
    tensor_storage_size_bytes = get_collective_input_size_bytes(node)
    # Convert bytes to GB
    tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024

    # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
    # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    num_gpus_per_node = 8
    group_size = get_collective_group_size(node)
    nNodes = math.ceil(group_size / num_gpus_per_node)
    nRanks = group_size  # this is total # of gpus globally that participate in this collective op

    if nRanks <= 1:
        return 0

    # Assumes ring algorithm
    nccl_algo = NCCL_ALGO.RING
    nccl_proto = NCCL_PROTO.LL
    coll = get_collective_type(node)

    # =============== bandwidth computation ===============
    # First compute bandwidth in GB/s; then at the end, convert it to GB/ns
    bwIntra = torch._inductor.config.intra_node_bw
    bwInter = torch._inductor.config.inter_node_bw
    compCapIndex = get_gpu_type()
    index2 = nNodes - 1 if nNodes <= 2 else 2
    # LL: for single node, we look at GPU type; for multi-node, we look at CPU type
    index1 = compCapIndex if nNodes == 1 else 0
    llMaxBw = llMaxBws[index1][index2]

    # NOTE: each step of ring algorithm is synchronized,
    # and is bottlenecked by the slowest link which is the inter-node interconnect.
    # hence when nNodes >= 2, bw is inter-node bandwidth.
    # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc
    # have this as `if nNodes <= 2` which seems wrong. Corrected it here.
    bw = bwIntra if nNodes == 1 else bwInter
    nChannels = 2  # Assume # channels is 2
    busBw = nChannels * bw

    # Various model refinements
    busBw = min(
        llMaxBw,
        busBw
        * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0),
    )

    if coll == NCCL_COLL.ALL_REDUCE:
        nsteps = 2 * (nRanks - 1)
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nsteps = nRanks - 1

    # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
    ratio = (1.0 * nRanks) / nsteps  # type: ignore[possibly-undefined]
    bandwidth = busBw * ratio
    # Convert GB/s to GB/ns
    bandwidth_GB_per_ns = bandwidth / 1e9

    # =============== latency computation ===============
    intraHw = NCCL_HW.NVLINK

    if coll == NCCL_COLL.ALL_REDUCE:
        if nNodes > 1:
            nInterSteps = 2 * nNodes
        else:
            nInterSteps = 0
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nInterSteps = nNodes - 1

    # First compute latency in us; then at the end, convert it to ns
    latency = baseLat[nccl_algo][nccl_proto]
    intraLat = hwLat[intraHw][nccl_algo][nccl_proto]
    interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto]

    # Inter-node rings still have to launch nsteps * net overhead.
    netOverhead = 0.0
    if nNodes > 1:
        netOverhead = 1.0  # getNetOverhead(comm);
    intraLat = max(intraLat, netOverhead)
    latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat  # type: ignore[possibly-undefined]
    # Convert us to ns
    latency_ns = latency * 1e3

    # =============== final result ===============
    transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
    return transport_ns + latency_ns


################################################################################################################
# The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
################################################################################################################
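
# A worked example of the estimate above, using assumed values (a 16-rank all_reduce spanning
# 2 nodes with a 1 GiB input, and an assumed inter-node bandwidth of 25 GB/s rather than the
# value read from Inductor config):
#   nRanks = 16, nNodes = 2            ->  llMaxBw = llMaxBws[0][1] = 39.0
#   busBw  = min(39.0, 2 * 25 * 1/4)   =   12.5 GB/s
#   nsteps = 2 * (16 - 1) = 30         ->  bandwidth = 12.5 * 16 / 30 ~= 6.67 GB/s = 6.67e-9 GB/ns
#   nInterSteps = 2 * 2 = 4            ->  latency = 6.6 + (30 - 4) * max(0.6, 1.0) + 4 * 2.7 = 43.4 us
#   runtime ~= 1.0 / 6.67e-9 + 4.34e4  ~=  1.5e8 ns (~150 ms)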