# wrapper_benchmark.py

# mypy: allow-untyped-defs
import dataclasses
import tempfile
from collections import defaultdict

import torch
from torch.autograd import DeviceType

from .runtime.runtime_utils import (
    create_bandwidth_info_str,
    do_bench_gpu,
    get_num_bytes,
)

_kernel_category_choices = [
    "foreach",
    "persistent_reduction",
    "pointwise",
    "reduction",
    "split_scan",
    "template",
]


def get_kernel_category_by_source_code(src_code):
    """
    Similar to get_kernel_category but works on the source code. Call this API
    if we have not compiled src_code into a module yet.
    """
    choices = [
        ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code
    ]
    if len(choices) == 1:
        return choices[0]
    else:
        return "unknown"
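
# Illustration: Inductor-generated Triton source decorates each kernel with
# "@triton_heuristics.<category>(...)" (e.g. "@triton_heuristics.pointwise(...)"),
# so exactly one of the substrings built above should normally match.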


def get_kernel_category(kernel_mod):
    """
    Given the module defining a triton kernel, return the category of the kernel.
    Category can be one of:
    - pointwise
    - reduction
    - persistent_reduction

    Currently we simply decide the category depending on what decorator is imported
    by the kernel.
    """
    choices = [ch for ch in _kernel_category_choices if ch in kernel_mod.__dict__]
    if len(choices) == 1:
        return choices[0]
    else:
        return "unknown"
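
# Note: as the docstring says, this lookup relies on the kernel module having
# imported the heuristics decorator it uses (e.g. `pointwise`) by name, which is
# why membership in kernel_mod.__dict__ identifies the category.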


def get_triton_kernel(mod):
    from torch._inductor.runtime.triton_heuristics import CachingAutotuner

    cand_list = [
        v
        for k, v in mod.__dict__.items()
        if k.startswith("triton_") and isinstance(v, CachingAutotuner)
    ]
    assert len(cand_list) == 1
    return cand_list[0]
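
# Note: a kernel module produced by Inductor is expected to bind its compiled
# kernel to a module-level name starting with "triton_" as a CachingAutotuner;
# the assert above enforces that there is exactly one such kernel per module.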


def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
    """
    An experimental API used only when config.benchmark_kernel is true.

    Run the kernel benchmarks for all the kernels cached in PyCodeCache.
    Used in the compiled modules.

    We put this method here rather than codegen'ing it, for convenience, since its
    implementation does not change based on the graph module being compiled.
    """
    from torch._inductor.codecache import PyCodeCache

    nfound = 0
    for kernel_key, kernel_mod in PyCodeCache.cache.items():
        if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
            continue

        triton_kernel = get_triton_kernel(kernel_mod)
        kernel_category = get_kernel_category(kernel_mod)
        args = kernel_mod.get_args()
        num_in_out_ptrs = len(
            [
                arg_name
                for arg_name in triton_kernel.fn.arg_names
                if arg_name.startswith("in_out_ptr")
            ]
        )
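        # num_gb is the estimated memory traffic of one kernel invocation in GB.
        # Prefer the value the kernel recorded in inductor_meta ("kernel_num_gb");
        # otherwise fall back to summing the argument sizes (in_out pointers are
        # presumably counted for both the read and the write).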
        num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None)
        if num_gb is None:
            num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9

        def get_info_str(ms, n_regs, n_spills, shared, prefix=""):
            if not any(x is None for x in [n_regs, n_spills, shared]):
                kernel_detail_str = (
                    f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem"
                )
            else:
                kernel_detail_str = ""

            gb_per_s = num_gb / (ms / 1e3)
            return create_bandwidth_info_str(
                ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str
            )

        kernel_desc = (
            f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}"
        )
        if benchmark_all_configs:
            assert hasattr(kernel_mod, "benchmark_all_configs")
            bench_result = kernel_mod.benchmark_all_configs(args)
            print(kernel_desc)
            for launcher, ms in bench_result.items():
                print(
                    f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
                )
        else:
            ms = do_bench_gpu(lambda: kernel_mod.call(args), rep=40, fast_flush=True)
            assert (
                len(triton_kernel.launchers) == 1
            ), "Autotuner should have selected the best config"
            launcher = triton_kernel.launchers[0]
            print(
                get_info_str(
                    ms,
                    launcher.n_regs,
                    launcher.n_spills,
                    launcher.shared,
                    prefix=f"{kernel_desc} ",
                )
            )

        nfound += 1
    if nfound == 0:
        print(
            "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel set to True."
        )
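

# Example invocation (illustrative only; this is normally reached through
# compiled_module_main below when a compiled module is run with -k):
#
#     benchmark_all_kernels("my_benchmark", benchmark_all_configs=False)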


@dataclasses.dataclass
class ProfileEvent:
    category: str
    key: str
    self_cuda_time_ms: float
    # the benchmark is run multiple times and we average the count across all the
    # runs. It should be an integer, but we define it as a float just in case.
    count: float


def parse_profile_event_list(benchmark_name, event_list, wall_time_ms, nruns):
    def get_self_cuda_time(ev):
        """
        ev.self_cuda_time_total is in microseconds. Convert to milliseconds and
        average across the nruns runs.
        """
        return ev.self_cuda_time_total / 1000 / nruns

    all_events = defaultdict(list)

    def add_event(ev, category):
        profile_ev = ProfileEvent(
            category=category,
            key=ev.key,
            self_cuda_time_ms=get_self_cuda_time(ev),
            count=ev.count / nruns,  # average across all runs
        )
        all_events[category].append(profile_ev)
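
    # Inductor prefixes generated kernel names with their category
    # (triton_poi_*, triton_red_*, triton_per_*); map those prefixes back to
    # readable category names and bucket the events accordingly.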
    for ev in event_list:
        assert not ev.is_legacy, "Don't support the legacy profiler"
        if ev.device_type == DeviceType.CPU:
            # ignore the event on CPU side
            continue

        category = "unknown"
        if ev.key.startswith("triton_"):
            if ev.key.startswith("triton_poi"):
                category = "triton_pointwise"
            elif ev.key.startswith("triton_red"):
                category = "triton_reduction"
            elif ev.key.startswith("triton_per"):
                category = "triton_persistent_reduction"
            else:
                category = "triton_unknown"

        add_event(ev, category)

    def report_category(category, profile_events):
        from tabulate import tabulate

        profile_events.sort(key=lambda ev: ev.self_cuda_time_ms, reverse=True)

        rows = []
        total_time = 0.0
        print(f"\n == {category} category kernels == ")
        for ev in profile_events:
            total_time += ev.self_cuda_time_ms
            percent = f"{ev.self_cuda_time_ms / wall_time_ms * 100:.2f}%"
            rows.append([ev.key[:120], ev.self_cuda_time_ms, ev.count, percent])
        rows.append(
            ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"]
        )
        print(
            tabulate(
                rows, headers=["Kernel", "Self CUDA TIME (ms)", "Count", "Percent"]
            )
        )
        return total_time
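
    # report_category returns the category's total self CUDA time so that report()
    # below can accumulate it into the overall GPU-busy percentage.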
    def report():
        category_list = [
            "triton_pointwise",
            "triton_reduction",
            "triton_persistent_reduction",
            "triton_unknown",
            "unknown",
        ]
        assert set(all_events.keys()).issubset(
            set(category_list)
        ), f"{list(all_events.keys())}"

        per_category_wall_time = {}
        total_cuda_ms = 0.0
        for category in category_list:
            if category in all_events:
                _time = report_category(category, all_events[category])
                per_category_wall_time[category] = _time
                total_cuda_ms += _time

        gpu_busy_percent = f"{total_cuda_ms / wall_time_ms * 100:.2f}%"
        print(f"\nPercent of time when GPU is busy: {gpu_busy_percent}")
        print(f"Total wall time {wall_time_ms:.3f} ms")

        # Output a line in this fixed format so the same line can be gathered from
        # all compiled modules across all benchmarks and tabulated.
        # Columns: benchmark_name, pointwise_percent, reduction_percent,
        #   persistent_reduction_percent, unknown_category_percent,
        #   GPU_busy_percent, wall_time_ms
        tabulate_line = f"Output for tabulate: {benchmark_name}"
        for category in category_list:
            percent = (
                f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%"
            )
            tabulate_line += f", {percent}"
        tabulate_line += f", {gpu_busy_percent}, {wall_time_ms:.3f}ms"

        print(tabulate_line)

    report()


def compiled_module_main(benchmark_name, benchmark_compiled_module_fn):
    """
    This is the function called in the __main__ block of a compiled module.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--benchmark-kernels",
        "-k",
        action="store_true",
        help="Whether to benchmark each individual kernel",
    )
    parser.add_argument(
        "--benchmark-all-configs",
        "-c",
        action="store_true",
        help="Whether to benchmark each individual config for a kernel",
    )
    parser.add_argument(
        "--profile",
        "-p",
        action="store_true",
        help="Whether to profile the compiled module",
    )
    args = parser.parse_args()

    if args.benchmark_kernels:
        benchmark_all_kernels(benchmark_name, args.benchmark_all_configs)
    else:
        times = 10
        repeat = 10
        wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000

        if not args.profile:
            return

        with torch.profiler.profile(record_shapes=True) as p:
            benchmark_compiled_module_fn(times=times, repeat=repeat)

        path = f"{tempfile.gettempdir()}/compiled_module_profile.json"
        p.export_chrome_trace(path)
        print(f"Profiling result for a compiled module of benchmark {benchmark_name}:")
        print(f"Chrome trace for the profile is written to {path}")
        event_list = p.key_averages(group_by_input_shape=True)
        print(event_list.table(sort_by="self_cuda_time_total", row_limit=10))
        parse_profile_event_list(
            benchmark_name, event_list, wall_time_ms, times * repeat
        )
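

# A minimal usage sketch (names are illustrative): an Inductor-generated compiled
# module typically ends with a __main__ block along these lines, where
# benchmark_compiled_module is the module's own helper that runs the compiled graph
# `times * repeat` times and returns the wall time in seconds.
#
#     if __name__ == "__main__":
#         from torch._inductor.wrapper_benchmark import compiled_module_main
#         compiled_module_main("my_benchmark", benchmark_compiled_module)
#
# Passing -k benchmarks each cached Triton kernel individually, -c additionally
# benchmarks every autotuning config, and -p profiles the whole compiled module.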