# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import logging
import multiprocessing
import os
import sys
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from functools import partial
from time import time
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING

import torch
from torch._dynamo.device_interface import get_registered_device_interfaces
from torch._inductor import config
from torch._inductor.codecache import (
    CodeCacheFuture,
    CppCodeCache,
    CppPythonBindingsCodeCache,
    CUDACodeCache,
    HalideCodeCache,
    LambdaFuture,
    TritonCodeCache,
    TritonFuture,
)
from torch._inductor.compile_worker.subproc_pool import (
    _warm_process_pool,
    AnyPool,
    SubprocPool,
)
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import (
    _set_triton_ptxas_path,
    _worker_compile_triton,
)
from torch.hub import _Faketqdm, tqdm

if TYPE_CHECKING:
    from torch._inductor.runtime.hints import HalideMeta

# timing metrics for time spent in the compilation
_cumulative_compile_time = 0.0
_t0: Optional[float] = None

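# Artifact logger used to dump generated kernel source; enabled via the
# "kernel_code" logging artifact (for example TORCH_LOGS="kernel_code").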
kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")


def pre_fork_setup():
    """
    Setup that must be done prior to forking with a process pool.
    """
    # ensure properties have been calculated before processes
    # are forked
    caching_device_properties()

    # Computing the triton key can be slow. If we call it before fork,
    # it will be cached for the forked subprocesses.
    try:
        from triton.compiler.compiler import triton_key

        triton_key()
    except ModuleNotFoundError:
        # Might not be installed.
        pass


def caching_device_properties():
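    """
    Query device properties on every available registered device so the
    values are computed once in the parent process and inherited by forked
    workers.
    """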
    for _, device_interface in get_registered_device_interfaces():
        if device_interface.is_available():
            device_interface.Worker.get_device_properties()


def _compile_start() -> None:
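    """Open a compilation timing region (no-op if one is already open)."""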
    global _t0
    if _t0 is None:
        _t0 = time()


def _compile_end() -> None:
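    """Close the current timing region and add its duration to the
    cumulative compile-time counter."""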
    global _cumulative_compile_time, _t0
    if _t0 is not None:
        t1 = time()
        _cumulative_compile_time += t1 - _t0
        _t0 = None
        # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)


_IS_WINDOWS = sys.platform == "win32"

log = logging.getLogger(__name__)

# Used to keep track of all process pools invoked so far.
_pool_set: Set[AnyPool] = set()


def shutdown_compile_workers() -> None:
    """Shut down all outstanding compile-worker pools."""
    for pool in _pool_set:
        pool.shutdown()
    after_fork()


def after_fork():
    """Reset pools to initial state without shutting them down."""
    _pool_set.clear()
    AsyncCompile.process_pool.cache_clear()


try:
    os.register_at_fork(after_in_child=after_fork)
except AttributeError:
    pass  # register_at_fork does not exist on Windows


class AsyncCompile:
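    """
    Dispatches Inductor kernel compilation to a thread or process pool when
    config.compile_threads > 1, returning future-like objects that wait()
    later resolves in place. With compile_threads <= 1, everything compiles
    eagerly on the calling thread.
    """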

    def __init__(self) -> None:
        pass

    @staticmethod
    @functools.lru_cache(1)
    def pool() -> ThreadPoolExecutor:
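        """Lazily create (and cache) the shared thread pool."""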
        assert config.compile_threads > 1
        return ThreadPoolExecutor(config.compile_threads)

    @staticmethod
    @functools.lru_cache(1)
    def process_pool() -> AnyPool:
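        """Lazily create (and cache) the worker pool: either a SubprocPool or
        a forking ProcessPoolExecutor, depending on
        config.worker_start_method."""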
        assert config.compile_threads > 1
        pool: AnyPool
        if config.worker_start_method == "subprocess":
            # SubprocPool is a wrapper around ProcessPoolExecutor that forks
            # in a new process we control
            pool = SubprocPool(config.compile_threads)
        else:
            pre_fork_setup()
            ctx = multiprocessing.get_context(config.worker_start_method)
            pool = ProcessPoolExecutor(
                config.compile_threads,
                mp_context=ctx,
                initializer=partial(_async_compile_initializer, os.getpid()),
            )
            # when this pool is created in a subprocess object, the normal exit handler
            # doesn't run, and we need to register our own handler.
            # exitpriority has to be high, because another one of the finalizers will
            # kill the worker thread that sends the shutdown message to the workers...
            multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)

        _pool_set.add(pool)
        return pool

    @classmethod
    def warm_pool(cls) -> None:
        if config.compile_threads <= 1:
            return
        _compile_start()
        _warm_process_pool(cls.process_pool(), config.compile_threads)
        _compile_end()

    @classmethod
    def submit(cls, task: Callable[..., Any]) -> Any:
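        """Run task on the thread pool, or inline when compilation is serial."""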
        if config.compile_threads <= 1:
            return task()
        return cls.pool().submit(task)

    def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
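        """Load a Triton kernel from source; compile it on the process pool
        and return a TritonFuture, or precompile eagerly when serial."""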
        kernel_code_log.info("Triton Kernel:\n%s", source_code)
        _compile_start()
        _set_triton_ptxas_path()

        kernel = TritonCodeCache.load(kernel_name, source_code)
        if config.compile_threads > 1:
            return TritonFuture(
                kernel,
                self.process_pool().submit(
                    _worker_compile_triton,
                    kernel._reload_in_subproc,
                ),
            )
        else:
            kernel.precompile()
            return kernel

    def multi_kernel(self, *args, **kwargs) -> Any:
        from torch._inductor.codegen.multi_kernel import MultiKernelCall

        # no need to call this in parallel since the sub-kernels are already parallel tasks
        return MultiKernelCall(*args, **kwargs)

    def cpp(self, source_code: str):
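        """Compile a C++ kernel, returning either the loaded kernel or a
        LambdaFuture that resolves to it."""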
        kernel_code_log.info("CPP Kernel:\n%s", source_code)
        if config.compile_threads <= 1:
            return CppCodeCache.load(source_code).kernel
        else:
            get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
            return LambdaFuture(lambda: get_result().kernel)

    def cpp_pybinding(self, argtypes: List[str], source_code: str):
        kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
        if config.compile_threads <= 1:
            return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
        else:
            get_result = CppPythonBindingsCodeCache.load_pybinding_async(
                argtypes, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def cuda(self, source_code, dst_file_ext):
        kernel_code_log.info("CUDA Kernel:\n%s", source_code)

        def task():
            return CUDACodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def halide(self, meta: HalideMeta, source_code: str):
        kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code)
        if config.compile_threads <= 1:
            return HalideCodeCache.generate_halide(meta, source_code)
        else:
            get_result = HalideCodeCache.generate_halide_async(
                meta, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def wait(self, scope: Dict[str, Any]) -> None:
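        """Block until every Future/CodeCacheFuture value in scope has
        finished, replacing each entry with its result and updating a
        progress bar along the way."""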
        num_kernels = len(
            [
                value
                for key, value in scope.items()
                if isinstance(value, (Future, CodeCacheFuture))
            ]
        )
        pbar = tqdm(
            total=num_kernels,
            desc="Inductor Compilation",
            disable=config.disable_progress,
            delay=0,
        )
        if config.compile_threads > 1:
            for key, result in scope.items():
                if config.verbose_progress and not isinstance(pbar, _Faketqdm):
                    pbar.set_postfix_str(key)
                if isinstance(result, (Future, CodeCacheFuture)):
                    scope[key] = result.result()
                    pbar.update(1)

        _compile_end()
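

# Warm the worker pool at import time so workers are ready before the first
# compile; skipped when TORCH_TNT_IN_USE=1 or TORCH_WARM_POOL is not "1".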
if (
    os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
    or os.environ.get("TORCH_WARM_POOL", "1") != "1"
):
    pass
else:
    AsyncCompile.warm_pool()
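
# A minimal usage sketch (assumed caller pattern, not part of this module):
# Inductor-generated wrapper code typically creates one AsyncCompile, registers
# kernels through the methods above, then resolves all pending futures:
#
#     async_compile = AsyncCompile()
#     kernel0 = async_compile.triton("kernel0", triton_src)  # may return a TritonFuture
#     async_compile.wait(globals())  # replaces futures with compiled kernels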