# device_interface.py

# mypy: allow-untyped-defs
import inspect
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union

import torch
from torch._streambase import _EventBase, _StreamBase

get_cuda_stream: Optional[Callable[[int], int]]
if torch.cuda._is_compiled():
    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
else:
    get_cuda_stream = None

_device_t = Union[torch.device, str, int, None]

# Device properties are recorded in the main process but used in worker
# processes.
caching_worker_device_properties: Dict[str, Any] = {}
caching_worker_current_devices: Dict[str, int] = {}


class DeviceInterfaceMeta(type):
    def __new__(metacls, *args, **kwargs):
        class_member = args[2]
        if "Event" in class_member:
            assert inspect.isclass(class_member["Event"]) and issubclass(
                class_member["Event"], _EventBase
            ), "DeviceInterface member Event should inherit from _EventBase"
        if "Stream" in class_member:
            assert inspect.isclass(class_member["Stream"]) and issubclass(
                class_member["Stream"], _StreamBase
            ), "DeviceInterface member Stream should inherit from _StreamBase"
        return super().__new__(metacls, *args, **kwargs)
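

# Illustrative note (not part of the original module): the metaclass check
# above rejects Event/Stream members that do not derive from
# _EventBase/_StreamBase at class-creation time. The name BadBackend is
# hypothetical.
#
#     class BadBackend(DeviceInterface):
#         Event = object  # AssertionError: ... should inherit from _EventBase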


class DeviceInterface(metaclass=DeviceInterfaceMeta):
    """
    This is a simple device runtime interface for Inductor. It enables custom
    backends to be integrated with Inductor in a device-agnostic way.
    """

    class device:
        def __new__(cls, device: _device_t):
            raise NotImplementedError

    class Worker:
        """
        Worker API to query device properties that will work in multiprocessing
        workers that cannot use the GPU APIs (due to fork() and initialization
        time issues). Properties are recorded in the main process before the
        workers are forked.
        """

        @staticmethod
        def set_device(device: int):
            raise NotImplementedError

        @staticmethod
        def current_device() -> int:
            raise NotImplementedError

        @staticmethod
        def get_device_properties(device: _device_t = None):
            raise NotImplementedError

    @staticmethod
    def current_device():
        raise NotImplementedError

    @staticmethod
    def set_device(device: _device_t):
        raise NotImplementedError

    @staticmethod
    def maybe_exchange_device(device: int) -> int:
        raise NotImplementedError

    @staticmethod
    def exchange_device(device: int) -> int:
        raise NotImplementedError

    @staticmethod
    def device_count():
        raise NotImplementedError

    @staticmethod
    def is_available() -> bool:
        raise NotImplementedError

    @staticmethod
    def stream(stream: torch.Stream):
        raise NotImplementedError

    @staticmethod
    def current_stream():
        raise NotImplementedError

    @staticmethod
    def set_stream(stream: torch.Stream):
        raise NotImplementedError

    @staticmethod
    def _set_stream_by_id(stream_id: int, device_index: int, device_type: int):
        raise NotImplementedError

    @staticmethod
    def get_raw_stream():
        raise NotImplementedError

    @staticmethod
    def synchronize(device: _device_t = None):
        raise NotImplementedError

    @staticmethod
    def get_device_properties(device: _device_t = None):
        raise NotImplementedError

    @staticmethod
    def get_compute_capability(device: _device_t = None):
        raise NotImplementedError
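

# Sketch (illustrative, not part of the original module): a custom backend
# integrates with Inductor by subclassing DeviceInterface and registering it.
# MyBackendEvent, MyBackendStream, and my_backend are hypothetical names.
#
#     class MyBackendInterface(DeviceInterface):
#         Event = MyBackendEvent    # must inherit from _EventBase
#         Stream = MyBackendStream  # must inherit from _StreamBase
#
#         @staticmethod
#         def is_available() -> bool:
#             return my_backend.is_available()
#
#     register_interface_for_device("my_backend", MyBackendInterface)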


class DeviceGuard:
    """
    This class provides a context manager for device switching. This is a
    stripped down version of torch.{device_name}.device.

    The context manager changes the current device to the given device index
    on entering the context and restores the original device on exiting.
    The device is switched using the provided device interface.
    """

    def __init__(self, device_interface: Type[DeviceInterface], index: Optional[int]):
        self.device_interface = device_interface
        self.idx = index
        self.prev_idx = -1

    def __enter__(self):
        if self.idx is not None:
            self.prev_idx = self.device_interface.exchange_device(self.idx)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        if self.idx is not None:
            self.idx = self.device_interface.maybe_exchange_device(self.prev_idx)
        return False
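

# Usage sketch (illustrative): run a block with a different current device and
# restore the previous one on exit. The device index 1 is hypothetical.
#
#     with DeviceGuard(get_interface_for_device("cuda"), 1):
#         ...  # code here sees device 1 as the current device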


class CudaInterface(DeviceInterface):
    device = torch.cuda.device

    # Register the Event and Stream classes into the backend interface.
    # Both are implemented and inherit from _EventBase and _StreamBase.
    Event = torch.cuda.Event
    Stream = torch.cuda.Stream

    class Worker:
        @staticmethod
        def set_device(device: int):
            caching_worker_current_devices["cuda"] = device

        @staticmethod
        def current_device() -> int:
            if "cuda" in caching_worker_current_devices:
                return caching_worker_current_devices["cuda"]
            return torch.cuda.current_device()

        @staticmethod
        def get_device_properties(device: _device_t = None):
            if device is not None:
                if isinstance(device, str):
                    device = torch.device(device)
                    assert device.type == "cuda"
                if isinstance(device, torch.device):
                    device = device.index
            if device is None:
                device = CudaInterface.Worker.current_device()

            if "cuda" not in caching_worker_device_properties:
                device_prop = [
                    torch.cuda.get_device_properties(i)
                    for i in range(torch.cuda.device_count())
                ]
                caching_worker_device_properties["cuda"] = device_prop

            return caching_worker_device_properties["cuda"][device]

    current_device = staticmethod(torch.cuda.current_device)
    set_device = staticmethod(torch.cuda.set_device)
    device_count = staticmethod(torch.cuda.device_count)
    stream = staticmethod(torch.cuda.stream)  # type: ignore[assignment]
    current_stream = staticmethod(torch.cuda.current_stream)
    set_stream = staticmethod(torch.cuda.set_stream)  # type: ignore[assignment]
    _set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id)  # type: ignore[assignment]
    synchronize = staticmethod(torch.cuda.synchronize)
    get_device_properties = staticmethod(torch.cuda.get_device_properties)  # type: ignore[assignment]
    get_raw_stream = staticmethod(get_cuda_stream)  # type: ignore[arg-type]
    exchange_device = staticmethod(torch.cuda._exchange_device)  # type: ignore[arg-type]
    maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device)  # type: ignore[arg-type]

    # Can be mock patched by @patch decorator.
    @staticmethod
    def is_available() -> bool:
        return torch.cuda.is_available()

    @staticmethod
    def get_compute_capability(device: _device_t = None):
        if torch.version.hip is None:
            major, minor = torch.cuda.get_device_capability(device)
            return major * 10 + minor
        else:
            return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0]
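

# Example (illustrative): on a CUDA device with capability (8, 6),
# get_compute_capability returns 86; on ROCm it returns the gcnArchName
# prefix, e.g. "gfx90a".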


get_xpu_stream: Optional[Callable[[int], int]]
if torch.xpu._is_compiled():
    from torch._C import _xpu_getCurrentRawStream as get_xpu_stream
else:
    get_xpu_stream = None


class XpuInterface(DeviceInterface):
    device = torch.xpu.device
    Event = torch.xpu.Event
    Stream = torch.xpu.Stream

    class Worker:
        @staticmethod
        def set_device(device: int):
            caching_worker_current_devices["xpu"] = device

        @staticmethod
        def current_device() -> int:
            if "xpu" in caching_worker_current_devices:
                return caching_worker_current_devices["xpu"]
            return torch.xpu.current_device()

        @staticmethod
        def get_device_properties(device: _device_t = None):
            if device is not None:
                if isinstance(device, str):
                    device = torch.device(device)
                    assert device.type == "xpu"
                if isinstance(device, torch.device):
                    device = device.index
            if device is None:
                device = XpuInterface.Worker.current_device()

            if "xpu" not in caching_worker_device_properties:
                device_prop = [
                    torch.xpu.get_device_properties(i)
                    for i in range(torch.xpu.device_count())
                ]
                caching_worker_device_properties["xpu"] = device_prop

            return caching_worker_device_properties["xpu"][device]

    current_device = staticmethod(torch.xpu.current_device)
    set_device = staticmethod(torch.xpu.set_device)
    device_count = staticmethod(torch.xpu.device_count)
    stream = staticmethod(torch.xpu.stream)  # type: ignore[assignment]
    current_stream = staticmethod(torch.xpu.current_stream)
    set_stream = staticmethod(torch.xpu.set_stream)  # type: ignore[assignment]
    _set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id)  # type: ignore[assignment]
    synchronize = staticmethod(torch.xpu.synchronize)
    get_device_properties = staticmethod(torch.xpu.get_device_properties)  # type: ignore[assignment]
    get_raw_stream = staticmethod(get_xpu_stream)  # type: ignore[arg-type]
    exchange_device = staticmethod(torch.xpu._exchange_device)  # type: ignore[arg-type]
    maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device)  # type: ignore[arg-type]

    # Can be mock patched by @patch decorator.
    @staticmethod
    def is_available() -> bool:
        return torch.xpu.is_available()

    @staticmethod
    def get_compute_capability(device: _device_t = None):
        return torch.xpu.get_device_capability(device)


device_interfaces: Dict[str, Type[DeviceInterface]] = {}
_device_initialized = False


def register_interface_for_device(
    device: Union[str, torch.device], device_interface: Type[DeviceInterface]
):
    if isinstance(device, torch.device):
        device = str(device)
    device_interfaces[device] = device_interface


def get_interface_for_device(device: Union[str, torch.device]) -> Type[DeviceInterface]:
    if isinstance(device, torch.device):
        device = str(device)
    if not _device_initialized:
        init_device_reg()
    if device in device_interfaces:
        return device_interfaces[device]
    raise NotImplementedError(f"No interface for device {device}")
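

# Usage sketch (illustrative): look up the interface for a device string and
# query it without hard-coding a backend.
#
#     iface = get_interface_for_device("cuda:0")
#     if iface.is_available():
#         iface.synchronize()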


def get_registered_device_interfaces() -> Iterable[Tuple[str, Type[DeviceInterface]]]:
    if not _device_initialized:
        init_device_reg()
    return device_interfaces.items()


def init_device_reg():
    global _device_initialized
    register_interface_for_device("cuda", CudaInterface)
    for i in range(torch.cuda.device_count()):
        register_interface_for_device(f"cuda:{i}", CudaInterface)

    register_interface_for_device("xpu", XpuInterface)
    for i in range(torch.xpu.device_count()):
        register_interface_for_device(f"xpu:{i}", XpuInterface)

    _device_initialized = True