autotune_process.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import contextlib
  4. import ctypes
  5. import dataclasses
  6. import functools
  7. import logging
  8. import os
  9. import queue
  10. import time
  11. import warnings
  12. from concurrent.futures import ThreadPoolExecutor
  13. from ctypes import byref, c_size_t, c_void_p, CDLL
  14. from typing import (
  15. Any,
  16. Callable,
  17. Dict,
  18. Iterable,
  19. List,
  20. Optional,
  21. Sequence,
  22. TYPE_CHECKING,
  23. Union,
  24. )
  25. import torch
  26. import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
  27. from torch import multiprocessing
  28. from torch._dynamo.testing import rand_strided
  29. from torch._inductor import ir
  30. from torch._inductor.codecache import (
  31. CppCodeCache,
  32. CUDACodeCache,
  33. DLLWrapper,
  34. get_hash,
  35. PyCodeCache,
  36. )
  37. if TYPE_CHECKING:
  38. from multiprocessing.process import BaseProcess
  39. from multiprocessing.queues import Queue
  40. from types import ModuleType
  41. from torch._inductor.select_algorithm import TritonTemplateCaller
  42. from . import config
  43. from .runtime.runtime_utils import do_bench_cpu, do_bench_gpu
  44. from .virtualized import V
  45. CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
  46. EXIT_HANDLER_REGISTERED = False
  47. log = logging.getLogger(__name__)
  48. # Used to synchronize between parent and child processes
  49. class Ping:
  50. pass
  51. class Pong:
  52. pass
  53. class NonzeroWorkspaceNotSupportedError(Exception):
  54. pass
  55. @contextlib.contextmanager
  56. def set_cuda_visible_device(device: Optional[int]):
  57. """
  58. Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the
  59. specified single device. If device is None, don't manipulate the environment.
  60. """
  61. if device is None:
  62. yield
  63. return
  64. current = os.environ.get(CUDA_VISIBLE_DEVICES)
  65. os.environ[CUDA_VISIBLE_DEVICES] = str(device)
  66. try:
  67. yield
  68. finally:
  69. if current is None:
  70. del os.environ[CUDA_VISIBLE_DEVICES]
  71. else:
  72. os.environ[CUDA_VISIBLE_DEVICES] = current
  73. @dataclasses.dataclass
  74. class TuningProcess:
  75. """
  76. Abstraction for launching a helper process to benchmark kernels. Spawns
  77. the parent process and uses multiprocessing queues to send benchmark
  78. requests and return results.
  79. """
  80. device: Optional[int] = None
  81. process: Optional[BaseProcess] = None
  82. request_queue: Optional[Queue[Any]] = None
  83. response_queue: Optional[Queue[Any]] = None
  84. @staticmethod
  85. def process_main(
  86. request_queue: Queue[Any],
  87. response_queue: Queue[Any],
  88. ) -> None:
  89. """
  90. Entry point for the child process.
  91. """
  92. log.debug(
  93. "Entering TuningProcess child. Visible devices = %s",
  94. os.environ.get(CUDA_VISIBLE_DEVICES),
  95. )
  96. try:
  97. TuningProcess.workloop(request_queue, response_queue)
  98. except Exception as ex:
  99. log.exception("Exception in TuningProcess")
  100. @staticmethod
  101. def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
  102. """
  103. Work loop for the benchmarking subprocess.
  104. """
  105. while True:
  106. obj = request_queue.get()
  107. if obj is None:
  108. break # None is a sentinel for the child to terminate
  109. elif isinstance(obj, Ping):
  110. response_queue.put(Pong())
  111. elif isinstance(obj, BenchmarkRequest):
  112. response_queue.put(obj.benchmark())
  113. else:
  114. raise RuntimeError(f"Invalid request type {type(obj)}")
  115. def valid(self) -> bool:
  116. """
  117. True if the sub-process has been initialized.
  118. """
  119. return (
  120. self.process is not None
  121. and self.request_queue is not None
  122. and self.response_queue is not None
  123. )
  124. def clear(self) -> None:
  125. """
  126. Reset to an uninitialized state.
  127. """
  128. self.process = self.request_queue = self.response_queue = None
  129. def initialize(self) -> None:
  130. """
  131. Create child process, request/response queues, and do the warm up.
  132. Set the environment to make only the provided GPU device visible
  133. to the process.
  134. """
  135. if self.valid():
  136. return
  137. # cuda runtime does not work with "fork", use "spawn" to start processes.
  138. ctx = multiprocessing.get_context("spawn")
  139. self.request_queue = ctx.Queue()
  140. self.response_queue = ctx.Queue()
  141. self.process = ctx.Process(
  142. target=self.process_main,
  143. args=(
  144. self.request_queue,
  145. self.response_queue,
  146. ),
  147. )
  148. assert self.process is not None
  149. with set_cuda_visible_device(self.device):
  150. self.process.start()
  151. def put(self, obj: Any) -> None:
  152. """
  153. Push a work item to the child process.
  154. """
  155. # In case of a prior crash, ensure the subprocess is running
  156. self.initialize()
  157. assert self.request_queue is not None
  158. self.request_queue.put(obj)
  159. def get(
  160. self, result_timeout=120.0, graceful_timeout=3.0, terminate_timeout=1.0
  161. ) -> Any:
  162. """
  163. Get a response from the child process. Raises queue.Empty on timeout
  164. or if the process dies.
  165. This method is (so far) only used by TuningProcessPool, where torch._inductor.config entries are being used
  166. to populate the timeouts:
  167. Arguments:
  168. @param result_timeout: Timeout in seconds, defaults to 120.0 or to
  169. config.max_autotune_subproc_result_timeout_seconds when called by TuningProcessPool
  170. @param graceful_timeout: Timeout in seconds to allow graceful shutdown (SIGTERM is sent after this time).
  171. Defaults to 3.0 or to config.max_autotune_subproc_graceful_timeout_seconds
  172. @param terminate_timeout: Timeout in seconds after SIGTERM, until we send SIGKILL if the process
  173. remains alive. Defaults to 1.0 or to
  174. config.max_autotune_subproc_terminate_timeout_seconds.
  175. Returns:
  176. A response from the child process (Any type)
  177. """
  178. assert self.process is not None
  179. assert self.response_queue is not None
  180. while True:
  181. try:
  182. remaining_timeout = result_timeout
  183. res = None
  184. while remaining_timeout is not None and remaining_timeout >= 1.0:
  185. remaining_timeout -= 0.5
  186. try:
  187. res = self.response_queue.get(timeout=0.5)
  188. break
  189. except queue.Empty:
  190. if not self.process.is_alive():
  191. raise # is being caught a few lines below
  192. if res is None:
  193. res = self.response_queue.get(timeout=remaining_timeout)
  194. return res
  195. except queue.Empty:
  196. status = self.process.exitcode
  197. if status is None:
  198. self.kill(
  199. graceful_timeout=graceful_timeout,
  200. terminate_timeout=terminate_timeout,
  201. )
  202. else:
  203. # child process crashed
  204. self.clear()
  205. raise
  206. def terminate(self) -> None:
  207. """
  208. Signal the child process to terminate.
  209. """
  210. if self.valid():
  211. assert self.process is not None
  212. assert self.request_queue is not None
  213. self.request_queue.put(None)
  214. def wait(self) -> None:
  215. """
  216. Wait for the child process to exit.
  217. """
  218. if self.process is not None:
  219. self.process.join()
  220. self.clear()
  221. def kill(self, graceful_timeout=5.0, terminate_timeout=1.0) -> None:
  222. # Tries to kill the process, using a graceful_timeout in which the process
  223. # is allowed to exit gracefully. If the process is still alive,
  224. # it will be terminated. If that is not sufficient to end it
  225. # within terminate_timeout seconds, it will be killed.
  226. if self.process is not None:
  227. self.terminate()
  228. self.process.join(timeout=graceful_timeout)
  229. if self.process.is_alive():
  230. log.warning(
  231. "Sending SIGTERM to process with PID %d",
  232. self.process.pid,
  233. )
  234. self.process.terminate()
  235. self.process.join(timeout=terminate_timeout)
  236. if self.process.is_alive():
  237. log.error(
  238. "Sending SIGKILL to process with PID %d",
  239. self.process.pid,
  240. )
  241. self.process.kill() # This should definitely end the process
  242. self.clear()
  243. @dataclasses.dataclass
  244. class TuningProcessPool:
  245. """
  246. Maintains a pool of TuningProcesses to benchmark kernels in parallel
  247. across devices. By default, we create one TuningProcess per device and
  248. set the sub-process environment to make only that device visible.
  249. """
  250. processes: Optional[queue.Queue[TuningProcess]] = None
  251. executor: Optional[ThreadPoolExecutor] = None
  252. def initialize(self) -> None:
  253. """
  254. Start the child processes.
  255. """
  256. assert (self.processes is None) == (self.executor is None)
  257. if self.processes is not None:
  258. return
  259. devices = self.get_device_list()
  260. log.debug("Sub-process autotune device list: %s", devices)
  261. # Launch the child processes and push a msg to "warm up"
  262. self.processes = queue.Queue()
  263. for device in devices:
  264. p = TuningProcess(device=device)
  265. p.initialize()
  266. p.put(Ping())
  267. self.processes.put(p)
  268. # Wait for the initialization to finish
  269. for p in self.processes.queue:
  270. assert isinstance(p.get(result_timeout=None), Pong)
  271. # Use a thread pool to manage distributing work to the subprocesses.
  272. # Threads block on an available process, so it makes sense to match
  273. # the number of threads with the number of devices.
  274. self.executor = ThreadPoolExecutor(max_workers=len(devices))
  275. # Register the exit handler for the parent process so it will terminate
  276. # the child processes.
  277. global EXIT_HANDLER_REGISTERED
  278. if not EXIT_HANDLER_REGISTERED:
  279. EXIT_HANDLER_REGISTERED = True
  280. import atexit
  281. atexit.register(self.terminate)
  282. def get_device_list(self) -> Sequence[Optional[int]]:
  283. """
  284. Gather the list of devices to be used in the pool.
  285. """
  286. if not config.autotune_multi_device:
  287. # Don't use multiple devices
  288. return [None]
  289. count = torch.cuda.device_count()
  290. # If the user specified the visible devices in the env, use those.
  291. if CUDA_VISIBLE_DEVICES in os.environ:
  292. devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")]
  293. assert len(devices) <= count
  294. return devices
  295. return list(range(count))
  296. def terminate(self) -> None:
  297. """
  298. Signal all child processes to terminate.
  299. """
  300. if self.executor is not None:
  301. self.executor.shutdown()
  302. self.executor = None
  303. if self.processes is not None:
  304. for p in self.processes.queue:
  305. p.terminate()
  306. for p in self.processes.queue:
  307. p.wait()
  308. self.processes = None
  309. def target(self, choice: TritonTemplateCaller) -> float:
  310. """
  311. Entry point for the thread-pool helper threads: Wait for an open TuningProcess,
  312. remove it from the queue, execute the benchmark in that subprocess, and return
  313. the TuningProcess to the queue.
  314. """
  315. assert choice.bmreq is not None
  316. assert self.processes is not None
  317. process = self.processes.get()
  318. process.put(choice.bmreq)
  319. try:
  320. return process.get(
  321. config.max_autotune_subproc_result_timeout_seconds,
  322. config.max_autotune_subproc_graceful_timeout_seconds,
  323. config.max_autotune_subproc_terminate_timeout_seconds,
  324. )
  325. except queue.Empty:
  326. warnings.warn(
  327. f"Failed to benchmark choice '{choice}'. It will be ignored. "
  328. "Please debug the root cause in case the choice can bring perf gains."
  329. )
  330. # set to INF so this choice will be ignored
  331. return float("inf")
  332. finally:
  333. self.processes.put(process)
  334. def benchmark(
  335. self,
  336. choices: List[TritonTemplateCaller],
  337. ) -> Dict[TritonTemplateCaller, float]:
  338. """
  339. Benchmark each choice in a separate process.
  340. """
  341. assert self.processes is not None, "Tuning process pool is not initialized"
  342. assert self.executor is not None
  343. results = {}
  344. # Use a ThreadExecutorPool to spread the work across the subprocesses and
  345. # to grab subprocesses as soon as they're free.
  346. for choice, result in zip(choices, self.executor.map(self.target, choices)):
  347. results[choice] = result
  348. return results
  349. tuning_pool = TuningProcessPool()
  350. LayoutOrBuffer = Union[ir.Layout, ir.Buffer]
  351. @dataclasses.dataclass
  352. class TensorMeta:
  353. device: torch.device
  354. dtype: torch.dtype
  355. sizes: torch._prims_common.ShapeType
  356. strides: torch._prims_common.StrideType
  357. offset: int
  358. name: Optional[str] = None
  359. @classmethod
  360. def from_irnodes(
  361. cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
  362. ) -> Union[TensorMeta, List[TensorMeta]]:
  363. if isinstance(irnodes, Sequence):
  364. result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
  365. assert all(isinstance(x, TensorMeta) for x in result)
  366. return result
  367. node = irnodes
  368. if isinstance(node, ir.Layout):
  369. node = ir.Buffer("fake", node)
  370. dtype = node.get_dtype()
  371. assert dtype is not None
  372. return TensorMeta(
  373. device=node.get_device(),
  374. dtype=dtype,
  375. sizes=V.graph.sizevars.size_hints(
  376. node.get_size(),
  377. fallback=config.unbacked_symint_fallback,
  378. ),
  379. strides=V.graph.sizevars.size_hints(
  380. node.get_stride(),
  381. fallback=config.unbacked_symint_fallback,
  382. ),
  383. offset=V.graph.sizevars.size_hint(
  384. node.get_layout().offset,
  385. fallback=config.unbacked_symint_fallback,
  386. ),
  387. name=node.get_name(),
  388. )
  389. def to_tensor(self) -> torch.Tensor:
  390. return rand_strided(
  391. self.sizes,
  392. self.strides,
  393. device=self.device,
  394. dtype=self.dtype,
  395. extra_size=self.offset,
  396. )
  397. @dataclasses.dataclass
  398. class BenchmarkRequest:
  399. """
  400. Only handle triton template benchmark for now. The extern kernel benchmark
  401. can be done inside the same process since they usually don't cause crash.
  402. Important: Instances of this class and subclasses have to be serializable
  403. across process boundaries. Do not put CUDA Tensors in here!
  404. """
  405. def __init__(
  406. self,
  407. kernel_name: str,
  408. input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
  409. output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
  410. extra_args: Iterable[Any],
  411. ):
  412. # the kernel name defined in the module
  413. self.kernel_name = kernel_name
  414. if isinstance(input_tensor_meta, TensorMeta):
  415. input_tensor_meta = [input_tensor_meta]
  416. self.input_tensor_meta = input_tensor_meta
  417. if isinstance(output_tensor_meta, (tuple, list)):
  418. assert len(output_tensor_meta) == 1
  419. output_tensor_meta = output_tensor_meta[0]
  420. self.output_tensor_meta = output_tensor_meta
  421. self.extra_args = extra_args
  422. def make_run_fn(
  423. self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
  424. ) -> Callable[[], None]:
  425. raise NotImplementedError
  426. def cleanup_run_fn(self) -> None:
  427. pass
  428. def do_bench(
  429. self,
  430. fn,
  431. *input_tensors: torch.Tensor,
  432. output_tensor: Optional[torch.Tensor] = None,
  433. ) -> float:
  434. raise NotImplementedError
  435. def benchmark(
  436. self,
  437. *input_tensors: torch.Tensor,
  438. output_tensor: Optional[torch.Tensor] = None,
  439. ) -> float:
  440. debug = log.isEnabledFor(logging.DEBUG)
  441. if debug:
  442. start_ts = time.time()
  443. # create args and out tensor
  444. if output_tensor is None:
  445. assert len(input_tensors) == 0
  446. input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta)
  447. output_tensor = self.output_tensor_meta.to_tensor()
  448. if debug:
  449. create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
  450. start_ts = time.time()
  451. try:
  452. fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)
  453. except NonzeroWorkspaceNotSupportedError:
  454. # Skipping all ops with nonzero workspace requirements
  455. log.info("Skipping op due to nonzero workspace requirement")
  456. return float("inf")
  457. if debug:
  458. load_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
  459. start_ts = time.time()
  460. out = self.do_bench(fn, *input_tensors, output_tensor)
  461. if debug:
  462. bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
  463. log.debug(
  464. "InChildProcess %s: load %f, create tensor %f, bench %f",
  465. str(self),
  466. load_elapse, # type: ignore[possibly-undefined]
  467. create_tensor_elapse, # type: ignore[possibly-undefined]
  468. bench_elapse,
  469. )
  470. self.cleanup_run_fn()
  471. return out
  472. class TestBenchmarkRequest(BenchmarkRequest):
  473. """
  474. Supports unit testing. Defined in this file so that the TuningProcess
  475. sub-process knows how to unpickle these objects.
  476. """
  477. def __init__(self, value: Optional[float] = None) -> None:
  478. self.value = value
  479. def benchmark(
  480. self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
  481. ) -> float:
  482. if self.value is None:
  483. raise Exception("Failed to run") # noqa: TRY002
  484. return self.value
  485. class GPUDeviceBenchmarkRequest(BenchmarkRequest):
  486. def do_bench(
  487. self,
  488. fn,
  489. *input_tensors: torch.Tensor,
  490. output_tensor: Optional[torch.Tensor] = None,
  491. ) -> float:
  492. device_idx_set = {
  493. tensor.device.index
  494. for tensor in [*input_tensors, output_tensor]
  495. if isinstance(tensor, torch.Tensor)
  496. and tensor.is_cuda
  497. and tensor.device.index is not None
  498. }
  499. assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}"
  500. if len(device_idx_set) == 1:
  501. device_idx = next(iter(device_idx_set))
  502. else:
  503. device_idx = torch.cuda.current_device()
  504. with torch.cuda.device(device_idx):
  505. out = do_bench_gpu(fn)
  506. torch.cuda.synchronize() # shake out any CUDA errors
  507. return out
  508. class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest):
  509. # Important: Instances of this class have to be serializable
  510. # across process boundaries. Do not put CUDA Tensors in here!
  511. def __init__(
  512. self,
  513. kernel_name: str,
  514. input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
  515. output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
  516. extra_args: Iterable[Any],
  517. module_path: str, # the path of the module defining the triton kernel
  518. module_cache_key: str,
  519. grid: List[int],
  520. num_stages: int,
  521. num_warps: int,
  522. matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
  523. ):
  524. super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
  525. self.module_path = module_path
  526. self.module_cache_key = module_cache_key
  527. self.grid = grid
  528. self.num_stages = num_stages
  529. self.num_warps = num_warps
  530. self.matrix_instr_nonkdim = matrix_instr_nonkdim
  531. def make_run_fn(
  532. self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
  533. ) -> Callable[[], None]:
  534. mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
  535. log.debug(
  536. "benchmark module key: %s, path: %s",
  537. self.module_cache_key,
  538. self.module_path,
  539. )
  540. run_method = getattr(mod, self.kernel_name).run
  541. extra_args = list(self.extra_args)
  542. # Newer version of triton add warmup argument to JITFunction.run.
  543. # This code handles backward-compatibility.
  544. warmup_arg = {}
  545. import inspect
  546. if "warmup" in inspect.signature(run_method).parameters:
  547. warmup_arg["warmup"] = False
  548. from torch._C import _cuda_getCurrentRawStream as get_raw_stream
  549. if torch.version.hip and self.matrix_instr_nonkdim != 0:
  550. return functools.partial(
  551. run_method,
  552. *input_tensors,
  553. output_tensor,
  554. *self.extra_args,
  555. grid=self.grid,
  556. **warmup_arg,
  557. stream=get_raw_stream(self.output_tensor_meta.device.index),
  558. )
  559. else:
  560. return functools.partial(
  561. run_method,
  562. *input_tensors,
  563. output_tensor,
  564. *self.extra_args,
  565. grid=self.grid,
  566. **warmup_arg,
  567. stream=get_raw_stream(self.output_tensor_meta.device.index),
  568. )
  569. def precompile(self):
  570. mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
  571. getattr(mod, self.kernel_name).precompile()
  572. def __str__(self) -> str:
  573. return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
  574. class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest):
  575. # Important: Instances of this class have to be serializable
  576. # across process boundaries. Do not put CUDA Tensors in here!
  577. def __init__(
  578. self,
  579. kernel_name: str,
  580. input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
  581. output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
  582. extra_args: Iterable[Any],
  583. source_code: str,
  584. ):
  585. super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
  586. self.source_code = source_code
  587. self.workspace_size: int = 0
  588. self.workspace: Optional[torch.Tensor] = None
  589. self.DLL: Optional[DLLWrapper] = None
  590. self._workspace_size_updated = False
  591. self.hash_key: str = ""
  592. self.source_file: str = ""
  593. self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")
  594. def precompile(self):
  595. # Prepopulate CUDACodeCache
  596. # may happen in separate Threadpool
  597. log.debug("Precompiling %s", self)
  598. CUDACodeCache.compile(self.source_code, "so")
  599. log.debug("Done precompiling %s", self)
  600. def make_run_fn(
  601. self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
  602. ) -> Callable[[], None]:
  603. self.ensure_dll_loaded()
  604. self.update_workspace_size()
  605. args = [
  606. c_void_p(tensor.data_ptr())
  607. for tensor in list(input_tensors) + [output_tensor]
  608. ]
  609. log.debug(
  610. "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
  611. self.kernel_name,
  612. self.source_file,
  613. self.hash_key,
  614. self.DLL,
  615. args,
  616. self.extra_args,
  617. )
  618. stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
  619. run_method = getattr(self.DLL, self.kernel_name)
  620. workspace_ptr = c_void_p(0)
  621. if self.workspace_size > 0:
  622. self.workspace = torch.zeros(
  623. (self.workspace_size + 7) // 8,
  624. dtype=torch.float64,
  625. device=output_tensor.device,
  626. )
  627. workspace_ptr = c_void_p(self.workspace.data_ptr())
  628. # Generate partial function.
  629. return functools.partial(
  630. run_method,
  631. *args,
  632. *self.extra_args,
  633. None, # null workspace size ptr
  634. workspace_ptr, # set workspace ptr,
  635. stream_ptr,
  636. )
  637. def update_workspace_size(self) -> None:
  638. if self._workspace_size_updated:
  639. return
  640. self.ensure_dll_loaded()
  641. unique_input_count = len({meta.name for meta in self.input_tensor_meta})
  642. args = [c_void_p(None) for _ in range(unique_input_count + 1)]
  643. stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
  644. run_method = getattr(self.DLL, self.kernel_name)
  645. # Retrieve workspace_size and initialize workspace.
  646. c_workspace_size = c_size_t()
  647. run_method(
  648. *args, # input ptrs and output ptrs
  649. *self.extra_args,
  650. byref(
  651. c_workspace_size
  652. ), # set workspace size ptr to retrieve workspace size
  653. None, # null workspace ptr
  654. stream_ptr,
  655. )
  656. torch.cuda.synchronize() # shake out any CUDA errors
  657. self.workspace_size = c_workspace_size.value
  658. log.debug(
  659. "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950
  660. self.workspace_size,
  661. self.kernel_name,
  662. self.source_file,
  663. self.hash_key,
  664. self.DLL,
  665. args,
  666. self.extra_args,
  667. )
  668. self._workspace_size_updated = True
  669. def ensure_dll_loaded(self):
  670. if self.DLL is None:
  671. self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
  672. self.source_code, "so"
  673. )
  674. def cleanup_run_fn(self) -> None:
  675. if self.DLL is not None:
  676. self.DLL.close()
  677. self.workspace = None
  678. def __str__(self) -> str:
  679. return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
  680. class CPUDeviceBenchmarkRequest(BenchmarkRequest):
  681. def do_bench(
  682. self,
  683. fn,
  684. *input_tensors: torch.Tensor,
  685. output_tensor: Optional[torch.Tensor] = None,
  686. ) -> float:
  687. return do_bench_cpu(fn)
  688. class CppBenchmarkRequest(CPUDeviceBenchmarkRequest):
  689. # Important: Instances of this class have to be serializable
  690. # across process boundaries. Do not put Tensors in here!
  691. def __init__(
  692. self,
  693. kernel_name: str,
  694. input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
  695. output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
  696. extra_args: Iterable[Any],
  697. source_code: str,
  698. ):
  699. super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
  700. self.source_code = source_code
  701. self.hash_key = get_hash(source_code)
  702. self.DLL: Optional[Union[CDLL, ModuleType]] = None
  703. def precompile(self):
  704. # Prepopulate CppCodeCache
  705. # may happen in separate Threadpool
  706. log.debug("Precompiling %s", self)
  707. CppCodeCache.load(self.source_code, cuda=False)
  708. log.debug("Done precompiling %s", self)
  709. def make_run_fn(
  710. self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
  711. ) -> Callable[[], None]:
  712. # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf
  713. self.DLL = CppCodeCache.load(self.source_code, cuda=False)
  714. args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]]
  715. log.debug(
  716. "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s",
  717. self.kernel_name,
  718. self.DLL,
  719. args,
  720. self.extra_args,
  721. )
  722. run_method = getattr(self.DLL, self.kernel_name)
  723. run_method.argtypes = [ctypes.c_ulonglong] * len(args)
  724. # Generate partial function.
  725. return functools.partial(
  726. run_method,
  727. *args,
  728. *self.extra_args,
  729. )
  730. def cleanup_run_fn(self) -> None:
  731. if self.DLL is not None:
  732. self.DLL.close()
  733. def __str__(self) -> str:
  734. return f"{self.kernel_name=}"
  735. def benchmark_in_sub_process(
  736. choices: List[TritonTemplateCaller],
  737. ) -> Dict[TritonTemplateCaller, float]:
  738. """
  739. Do benchmarking in a subprocess and return the perf number (latency).
  740. """
  741. return tuning_pool.benchmark(choices)