// DispatchStub.h
  1. #pragma once
  2. #include <c10/core/DeviceType.h>
  3. #include <c10/macros/Macros.h>
  4. #include <c10/util/Array.h>
  5. #include <atomic>
  6. #include <utility>
  7. #include <variant>
  8. // Implements instruction set specific function dispatch.
  9. //
  10. // Kernels that may make use of specialized instruction sets (e.g. AVX2) are
  11. // compiled multiple times with different compiler flags (e.g. -mavx2). A
  12. // DispatchStub contains a table of function pointers for a kernel. At runtime,
  13. // the fastest available kernel is chosen based on the features reported by
  14. // cpuinfo.
  15. //
  16. // Example:
  17. //
  18. // In native/MyKernel.h:
  19. // using fn_type = void(*)(const Tensor& x);
  20. // DECLARE_DISPATCH(fn_type, stub);
  21. //
  22. // In native/MyKernel.cpp
  23. // DEFINE_DISPATCH(stub);
  24. //
  25. // In native/cpu/MyKernel.cpp:
  26. // namespace {
  27. // // use anonymous namespace so that different cpu versions won't conflict
  28. // void kernel(const Tensor& x) { ... }
  29. // }
  30. // REGISTER_DISPATCH(stub, &kernel);
  31. //
  32. // To call:
  33. // stub(kCPU, tensor);
  34. //
  35. // TODO: CPU instruction set selection should be folded into whatever
  36. // the main dispatch mechanism is.
  37. //
  38. // Supported device types for registration:
  39. // - CPU: Central Processing Unit
  40. // - CUDA: NVIDIA GPUs
  41. // - HIP: AMD GPUs
  42. // - MPS: Apple Silicon GPUs (Metal Performance Shaders)
  43. // - PrivateUse1: Reserved for private/custom device types
  44. //
  45. // If you want to update the list of supported devices, add a new dispatch_ptr
  46. // member in DispatchStubImpl.h and update the get_call_ptr switch.
  47. // As well you will need to update the inlined list in 'is_device_supported`
  48. //
  49. //
  50. // ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
  51. C10_CLANG_DIAGNOSTIC_PUSH()
  52. C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")
  53. namespace at::native {
// CPU instruction-set capability levels, ordered by preference: a higher
// numeric value is a more specialized (and preferred) kernel variant.
// Exactly one SIMD family is compiled in per platform, selected by the
// HAVE_*_CPU_DEFINITION macros.
enum class CPUCapability {
  DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
  // PowerPC
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  // s390x
  ZVECTOR = 1,
#else
  // x86-64
  AVX2 = 1,
  AVX512 = 2,
#endif
  NUM_OPTIONS  // sentinel: number of capability levels, not a real level
};
// Enum for error types returned by the try_* lookup paths.
enum class ErrorType {
  MissingDeviceKernel,  // device is supported, but no kernel was registered
  DeviceNotSupported    // device type has no dispatch slot at all
};
// Alias for the return type using std::variant: holds either a type-erased
// kernel pointer (success) or an ErrorType (failure).
using DispatchResult = std::variant<void*, ErrorType>;
// Returns the best CPUCapability available at runtime (per the file-level
// comment, detected via cpuinfo).
CPUCapability get_cpu_capability();
template <typename FnPtr, typename T>
struct DispatchStub;
/**
 * The sole purpose of this class is to outline methods that don't need to be
 * specialized or otherwise inlined and duplicated (by the compiler due to
 * template expansion), since it causes size bloat if there are a significant
 * number of specialization of the DispatchStub<> class.
 *
 * Each parameter named after a capability (DEFAULT, AVX512, ...) is the
 * type-erased kernel pointer for that CPU variant; callers pass whichever
 * variants were compiled in (guarded by HAVE_*_CPU_DEFINITION).
 */
struct TORCH_API DispatchStubImpl {

  // The DispatchStubImpl::try_get_call_ptr() method is used to get the call
  // pointer for a given device type. If the call pointer is not found,
  // DispatchStubImpl::try_get_call_ptr() returns an ErrorType.
  // The main difference between try_get_call_ptr() and get_call_ptr() is that
  // try_get_call_ptr() will return the ErrorType and not raise an exception.
  DispatchResult try_get_call_ptr(
    c10::DeviceType device_type
    , void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Analogous to choose_cpu_impl(), but it will return the ErrorType and not
  // raise an exception.
  DispatchResult try_choose_cpu_impl(
    void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Like try_get_call_ptr(), but raises on failure instead of returning an
  // ErrorType (see the comment on try_get_call_ptr above).
  void* get_call_ptr(
    c10::DeviceType device_type
    , void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  /**
   * The CPU Dispatch actual method is chosen in decreasing order of preference by
   * DispatchStubImpl::choose_cpu_impl() in case none is found by
   * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
   */
  void* choose_cpu_impl(
    void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Fixing dispatch error in Windows debug builds.
  // See https://github.com/pytorch/pytorch/issues/22681 for more details.
  // NOTE(review): the _MSC_VER/_DEBUG branch deliberately leaves the members
  // without initializers — presumably to work around the linked issue;
  // confirm against the issue before changing.
#if defined(_MSC_VER) && defined(_DEBUG)
  std::atomic<void*> cpu_dispatch_ptr;
  void* cuda_dispatch_ptr;
  void* hip_dispatch_ptr;
  void* mps_dispatch_ptr;
  void* privateuse1_dispatch_ptr;
#else
  // Cached CPU kernel (atomic: filled lazily on first dispatch) plus one
  // slot per non-CPU device type, filled by the set_*_dispatch_ptr hooks.
  std::atomic<void*> cpu_dispatch_ptr{nullptr};
  void* cuda_dispatch_ptr = nullptr;
  void* hip_dispatch_ptr = nullptr;
  void* mps_dispatch_ptr = nullptr;
  void* privateuse1_dispatch_ptr = nullptr;
#endif
};
  173. template <typename rT, typename T, typename... Args>
  174. struct DispatchStub<rT (*)(Args...), T> {
  175. using FnPtr = rT (*) (Args...);
  176. DispatchStub() = default;
  177. DispatchStub(const DispatchStub&) = delete;
  178. DispatchStub& operator=(const DispatchStub&) = delete;
  179. private:
  180. FnPtr get_call_ptr(const c10::DeviceType device_type) {
  181. return reinterpret_cast<FnPtr>(
  182. impl.get_call_ptr(device_type
  183. , reinterpret_cast<void*>(DEFAULT)
  184. #ifdef HAVE_AVX512_CPU_DEFINITION
  185. , reinterpret_cast<void*>(AVX512)
  186. #endif
  187. #ifdef HAVE_AVX2_CPU_DEFINITION
  188. , reinterpret_cast<void*>(AVX2)
  189. #endif
  190. #ifdef HAVE_VSX_CPU_DEFINITION
  191. , reinterpret_cast<void*>(VSX)
  192. #endif
  193. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  194. , reinterpret_cast<void*>(ZVECTOR)
  195. #endif
  196. )
  197. );
  198. }
  199. public:
  200. template <typename... ArgTypes>
  201. rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
  202. FnPtr call_ptr = get_call_ptr(device_type);
  203. return (*call_ptr)(std::forward<ArgTypes>(args)...);
  204. }
  205. void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
  206. impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  207. }
  208. void set_hip_dispatch_ptr(FnPtr fn_ptr) {
  209. impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  210. }
  211. void set_mps_dispatch_ptr(FnPtr fn_ptr) {
  212. impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  213. }
  214. void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
  215. impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  216. }
  217. // Returns true if the dispatcher has a kernel registered for this device
  218. // type.
  219. bool is_device_supported(const c10::DeviceType device_type) {
  220. auto result = impl.try_get_call_ptr(device_type
  221. , reinterpret_cast<void*>(DEFAULT)
  222. #ifdef HAVE_AVX512_CPU_DEFINITION
  223. , reinterpret_cast<void*>(AVX512)
  224. #endif
  225. #ifdef HAVE_AVX2_CPU_DEFINITION
  226. , reinterpret_cast<void*>(AVX2)
  227. #endif
  228. #ifdef HAVE_VSX_CPU_DEFINITION
  229. , reinterpret_cast<void*>(VSX)
  230. #endif
  231. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  232. , reinterpret_cast<void*>(ZVECTOR)
  233. #endif
  234. );
  235. if (std::holds_alternative<ErrorType>(result)){
  236. return false;
  237. }
  238. return true;
  239. };
  240. static TORCH_API FnPtr DEFAULT;
  241. #ifdef HAVE_AVX512_CPU_DEFINITION
  242. static TORCH_API FnPtr AVX512;
  243. #endif
  244. #ifdef HAVE_AVX2_CPU_DEFINITION
  245. static TORCH_API FnPtr AVX2;
  246. #endif
  247. #ifdef HAVE_VSX_CPU_DEFINITION
  248. static TORCH_API FnPtr VSX;
  249. #endif
  250. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  251. static TORCH_API FnPtr ZVECTOR;
  252. #endif
  253. private:
  254. DispatchStubImpl impl;
  255. };
// Registrar helper types: each constructor stores `value` into the given
// stub's per-device slot. They are instantiated as static objects by the
// REGISTER_*_DISPATCH macros below, so registration happens at static-init
// time of the defining translation unit. NOTE(review): this is an anonymous
// namespace in a header, so each including TU gets its own copy of these
// types — presumably intentional here since they are only ever named through
// the macros; confirm before relying on their identity across TUs.
namespace {
template <typename DispatchStub>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_cuda_dispatch_ptr(value);
  }
};
template <typename DispatchStub>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_mps_dispatch_ptr(value);
  }
};
template <typename DispatchStub>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    // TODO: make this point at hip_dispatch_ptr
    // (intentional for now: the HIPify build treats HIP kernels as CUDA, so
    // HIP registration writes the CUDA slot — see REGISTER_DISPATCH below)
    stub.set_cuda_dispatch_ptr(value);
  }
};
template <typename DispatchStub>
struct RegisterPRIVATEUSE1Dispatch {
  RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_privateuse1_dispatch_ptr(value);
  }
};
} // anonymous namespace
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
// adding parentheses and using helper struct to get rid of the parentheses, do
// not work with MSVC. So do a `using`-declaration if you need to pass in such
// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
//
// Declares a dispatch stub `name` for kernel signature `fn`: a distinct
// non-copyable struct type plus a single `extern` instance shared by all TUs.
#define DECLARE_DISPATCH(fn, name) \
struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> { \
name##_DECLARE_DISPATCH_type() = default; \
name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete; \
name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
}; \
extern TORCH_API struct name##_DECLARE_DISPATCH_type name;
// Defines the single instance declared by DECLARE_DISPATCH.
#define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name
// Defines the static per-arch pointer (DEFAULT, AVX2, ...) of `name`'s
// DispatchStub base via explicit specialization, binding it to `fn`.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub<name##_DECLARE_DISPATCH_type::FnPtr, struct name##_DECLARE_DISPATCH_type>::arch = fn;
// Per-ISA helpers: expand to REGISTER_ARCH_DISPATCH when the matching
// HAVE_*_CPU_DEFINITION is set, and to nothing otherwise.
#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif
// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
REGISTER_AVX512_DISPATCH(name, fn) \
REGISTER_AVX2_DISPATCH(name, fn) \
REGISTER_VSX_DISPATCH(name, fn) \
REGISTER_ZVECTOR_DISPATCH(name, fn)
// Registers nullptr for every CPU arch, marking the stub as CPU-unsupported.
#define REGISTER_NO_CPU_DISPATCH(name) \
REGISTER_ALL_CPU_DISPATCH(name, nullptr)
// Device-side registration: each defines a static registrar object whose
// constructor stores `fn` into the stub's per-device slot at static init.
#define REGISTER_CUDA_DISPATCH(name, fn) \
static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_HIP_DISPATCH(name, fn) \
static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_MPS_DISPATCH(name, fn) \
static RegisterMPSDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
static RegisterPRIVATEUSE1Dispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
// ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
#ifdef CPU_CAPABILITY_AVX512
// NOTE: `((void*)(fn) ? nullptr : nullptr)` is deliberate — it always yields
// nullptr (opting the kernel out of AVX512) while still referencing `fn`,
// which keeps the kernel symbol used and avoids unused warnings.
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
#else
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
  358. } // namespace at::native
  359. C10_CLANG_DIAGNOSTIC_POP()