Dispatch.h 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808
  1. #pragma once
  2. #include <ATen/core/DeprecatedTypeProperties.h>
  3. #include <c10/macros/Macros.h>
  4. #include <c10/util/Exception.h>
  5. #include <c10/util/Half.h>
  6. #include <c10/util/Metaprogramming.h>
  7. #include <c10/util/complex.h>
  8. #include <c10/util/string_view.h>
  9. #ifdef __CUDACC__
  10. #include <cuda.h> // For CUDA_VERSION
  11. #endif
  12. #ifdef TEMPLATE_SELECTIVE_BUILD
  13. #include <ATen/selected_mobile_ops.h>
  14. #else
  15. namespace at {
  16. /**
  17. * The method should_include_kernel_dtype() returns true/false
  18. * based on whether the switching code for a specific dtype should be
  19. * included based on build time constants generated from tracing model
  20. * execution. This method will be implmeneted via code-generation and
  21. * included in this file when code-gen is ready.
  22. */
  23. inline constexpr bool should_include_kernel_dtype(
  24. const char* /*kernel_tag_str*/,
  25. at::ScalarType /*scalar_type*/
  26. ) {
  27. return true;
  28. }
  29. } // namespace at
  30. #endif
  31. /**
  32. * In the Facebook internal build (using BUCK), this macro is enabled by
  33. * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
  34. * binary.
  35. */
  36. #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
  37. namespace at {
  38. namespace detail {
  39. TORCH_API void record_kernel_function_dtype(std::string name);
  40. }
  41. } // namespace at
  42. #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
  43. at::detail::record_kernel_function_dtype( \
  44. std::string(NAME) + "$" + toString(enum_type));
  45. #else
  46. #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
  47. #endif
  48. #define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \
  49. do { \
  50. if constexpr (!at::should_include_kernel_dtype( \
  51. at_dispatch_name, enum_type)) { \
  52. AT_ERROR( \
  53. "dtype '", \
  54. toString(enum_type), \
  55. "' not selected for kernel tag ", \
  56. at_dispatch_name); \
  57. } \
  58. } while (0)
  59. #define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
  60. case enum_type: { \
  61. AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
  62. using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
  63. return __VA_ARGS__(); \
  64. }
  65. #define AT_DISPATCH_CASE(enum_type, ...) \
  66. AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
  67. #define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \
  68. case enum_type: { \
  69. AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
  70. using scalar_t = scalar_type; \
  71. using underlying_t C10_UNUSED = typename scalar_t::underlying; \
  72. const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
  73. const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
  74. return __VA_ARGS__(); \
  75. }
  76. #define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  77. enum_type, scalar_type, bitwidth, qmin, qmax, ...) \
  78. case enum_type: { \
  79. AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
  80. using scalar_t = scalar_type; \
  81. using underlying_t C10_UNUSED = typename scalar_t::underlying; \
  82. const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
  83. const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
  84. C10_UNUSED int bit_width = bitwidth; \
  85. C10_UNUSED int64_t quant_min = qmin; \
  86. C10_UNUSED int64_t quant_max = qmax; \
  87. return __VA_ARGS__(); \
  88. }
  89. namespace detail {
  90. inline at::ScalarType scalar_type(at::ScalarType s) {
  91. return s;
  92. }
  93. C10_DEPRECATED_MESSAGE(
  94. "passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
  95. "pass an at::ScalarType instead")
  96. inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
  97. return t.scalarType();
  98. }
  99. C10_DEPRECATED_MESSAGE(
  100. "AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
  101. "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
  102. inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
  103. C10_DEPRECATED_MESSAGE(
  104. "AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
  105. "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
  106. "instead")
  107. inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
  108. } // namespace detail
  109. // The AT_DISPATCH_* family of macros provides the ability to
  110. // conveniently generate specializations of a kernel over all of the
  111. // dtypes we care about in PyTorch. We call it "dispatch" because
  112. // we are "dispatching" to the correct, dtype-specific kernel.
  113. //
  114. // A standard usage looks like:
  115. //
  116. // AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
  117. // // Your code here, with 'scalar_t' now defined to
  118. // // be the dtype in question
  119. // });
  120. //
  121. // There are many variations of this macro, so it's important to
  122. // understand exactly /which/ dtypes you want to get instantiated, as
  123. // well as what the "default" set is.
  124. //
  125. // The default set of dtypes that are instantiated (e.g., by
  126. // AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
  127. // and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
  128. // but NOT booleans (bool), half-precision floats (Half) or
// complex numbers (c10::complex<float>, c10::complex<double>).
  130. // This "cut" is somewhat historical (the default types are the
  131. // ones that TH historically supported), but it also reflects the
  132. // fact that the non-default types are "poorly" behaved (booleans
// are NOT integers mod 2, half precision operations essentially
  134. // don't exist on CPU, complex numbers are an experimental application).
  135. //
  136. // Here are the questions you should generally ask to decide which
  137. // dispatch you want:
  138. //
  139. // 1. Is this an integral or floating point specific operation?
  140. // (If so, you'll want one of the FLOATING or INTEGRAL macros.)
  141. //
// 2. Should half be supported? (If you're on CPU, the answer is almost
//    definitely no. If you do want support, use
//    AT_DISPATCH_FLOATING_TYPES_AND_HALF, or one of the *_AND macros with
//    at::ScalarType::Half; AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated.)
  145. //
  146. // Much rarer situations:
  147. //
  148. // 3. Should bool be supported? (You often have to write your kernel
  149. // differently if arithmetic operations are involved.) If so,
// use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool.
  151. //
  152. // 4. Should complex be supported? The answer is almost always no,
  153. // unless you are working on "generic" code that should work on
  154. // all dtypes.
  155. //
  156. // Parameters:
  157. // -----------
  158. //
  159. // 1. The NAME argument is a "tag" that is used to trace and then
  160. // conditionally compile fragments of the case statements such
  161. // that the kernel functions are specialized only for the dtypes
  162. // that are needed. The NAME parameter *must* be a build time
  163. // const char* (can't be std::string, etc...)
  164. //
  165. // Please ensure that the NAME is unique for every implementation
  166. // or you run the risk of over-including code for the kernel
  167. // functions. There is no risk of missing out on any code, so
  168. // it's mostly a risk of a Type-2 error, and not a Type-1 error.
  169. //
  170. // Switch-like syntax:
  171. // -------------------
  172. // There is also a switch-case like syntax which is useful if a kernel
  173. // needs to be specialized for particular scalar types
  174. //
  175. // AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
  176. // AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
  177. // op_integral<scalar_t>(iter);
  178. // })
  179. // AT_DISPATCH_CASE_FLOATING_TYPES([&] {
  180. // op_floating<scalar_t>(iter);
  181. // })
  182. // AT_DISPATCH_CASE(kBool, [&] {
  183. // op_bool(iter);
  184. // })
  185. // );
  186. //
  187. // For each AT_DISPATCH_FOO macro, there is a corresponding
  188. // AT_DISPATCH_CASE_FOO macro which can be used inside of an
  189. // AT_DISPATCH_SWITCH block.
  190. // NB: the the_type variable is not used, but we have kept it for
  191. // backwards compatibility. It's probably not used by anyone though;
  192. // but we're just being safe (and it doesn't hurt.) Note we must
  193. // use it to shut up warnings about unused store.
  194. #define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
  195. [&] { \
  196. const auto& the_type = TYPE; \
  197. constexpr const char* at_dispatch_name = NAME; \
  198. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  199. at::ScalarType _st = ::detail::scalar_type(the_type); \
  200. RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
  201. switch (_st) { \
  202. __VA_ARGS__ \
  203. default: \
  204. AT_ERROR( \
  205. '"', \
  206. at_dispatch_name, \
  207. "\" not implemented for '", \
  208. toString(_st), \
  209. "'"); \
  210. } \
  211. }()
  212. #define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
  213. AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
  214. AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
  215. #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  216. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
  217. #define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
  218. AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
  219. AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  220. AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
  221. #define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
  222. AT_DISPATCH_SWITCH( \
  223. TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
  224. #define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...) \
  225. AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
  226. AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
  227. #define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
  228. AT_DISPATCH_SWITCH( \
  229. TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))
  230. #define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
  231. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
  232. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  233. #define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  234. AT_DISPATCH_SWITCH( \
  235. TYPE, \
  236. NAME, \
  237. AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))
  238. #define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  239. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
  240. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  241. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
  242. #define AT_DISPATCH_FLOATING_TYPES_AND2( \
  243. SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  244. AT_DISPATCH_SWITCH( \
  245. TYPE, \
  246. NAME, \
  247. AT_DISPATCH_CASE_FLOATING_TYPES_AND2( \
  248. SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
  249. #define AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
  250. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
  251. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
  252. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  253. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  254. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
  255. #define AT_DISPATCH_FLOATING_TYPES_AND3( \
  256. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  257. AT_DISPATCH_SWITCH( \
  258. TYPE, \
  259. NAME, \
  260. AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
  261. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
  262. #define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
  263. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  264. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
  265. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  266. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  267. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  268. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
  269. #define AT_DISPATCH_FLOATING_TYPES_AND4( \
  270. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  271. AT_DISPATCH_SWITCH( \
  272. TYPE, \
  273. NAME, \
  274. AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
  275. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
  276. #define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
  277. AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
  278. AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
  279. #define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
  280. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))
  281. #define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
  282. AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \
  283. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  284. #define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  285. AT_DISPATCH_SWITCH( \
  286. TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))
  287. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
  288. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
  289. AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
  290. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
  291. AT_DISPATCH_SWITCH( \
  292. TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))
  293. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
  294. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  295. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  296. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \
  297. SCALARTYPE, TYPE, NAME, ...) \
  298. AT_DISPATCH_SWITCH( \
  299. TYPE, \
  300. NAME, \
  301. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \
  302. SCALARTYPE, __VA_ARGS__))
  303. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
  304. SCALARTYPE1, SCALARTYPE2, ...) \
  305. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  306. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  307. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
  308. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \
  309. SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  310. AT_DISPATCH_SWITCH( \
  311. TYPE, \
  312. NAME, \
  313. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
  314. SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
  315. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
  316. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
  317. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  318. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  319. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  320. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
  321. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \
  322. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  323. AT_DISPATCH_SWITCH( \
  324. TYPE, \
  325. NAME, \
  326. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
  327. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
  328. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
  329. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  330. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  331. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  332. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  333. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  334. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
  335. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
  336. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  337. AT_DISPATCH_SWITCH( \
  338. TYPE, \
  339. NAME, \
  340. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
  341. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
  342. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
  343. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
  344. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  345. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  346. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  347. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  348. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  349. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
  350. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5( \
  351. SCALARTYPE1, \
  352. SCALARTYPE2, \
  353. SCALARTYPE3, \
  354. SCALARTYPE4, \
  355. SCALARTYPE5, \
  356. TYPE, \
  357. NAME, \
  358. ...) \
  359. AT_DISPATCH_SWITCH( \
  360. TYPE, \
  361. NAME, \
  362. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
  363. SCALARTYPE1, \
  364. SCALARTYPE2, \
  365. SCALARTYPE3, \
  366. SCALARTYPE4, \
  367. SCALARTYPE5, \
  368. __VA_ARGS__))
  369. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
  370. SCALARTYPE1, \
  371. SCALARTYPE2, \
  372. SCALARTYPE3, \
  373. SCALARTYPE4, \
  374. SCALARTYPE5, \
  375. SCALARTYPE6, \
  376. ...) \
  377. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  378. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  379. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  380. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  381. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  382. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
  383. AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
  384. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \
  385. SCALARTYPE1, \
  386. SCALARTYPE2, \
  387. SCALARTYPE3, \
  388. SCALARTYPE4, \
  389. SCALARTYPE5, \
  390. SCALARTYPE6, \
  391. TYPE, \
  392. NAME, \
  393. ...) \
  394. AT_DISPATCH_SWITCH( \
  395. TYPE, \
  396. NAME, \
  397. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
  398. SCALARTYPE1, \
  399. SCALARTYPE2, \
  400. SCALARTYPE3, \
  401. SCALARTYPE4, \
  402. SCALARTYPE5, \
  403. SCALARTYPE6, \
  404. __VA_ARGS__))
  405. #define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
  406. AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
  407. AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
  408. AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
  409. AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
  410. AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)
  411. #define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  412. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
  413. #define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
  414. AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
  415. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  416. #define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  417. AT_DISPATCH_SWITCH( \
  418. TYPE, \
  419. NAME, \
  420. AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
  421. #define AT_DISPATCH_CASE_ALL_TYPES(...) \
  422. AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
  423. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)
  424. #define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
  425. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
  426. #define AT_DISPATCH_CASE_QINT_TYPES(...) \
  427. AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
  428. AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \
  429. AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)
  430. #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
  431. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))
  432. #define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
  433. AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__) \
  434. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  435. #define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  436. AT_DISPATCH_SWITCH( \
  437. TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))
  438. #define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \
  439. AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
  440. AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)
  441. #define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
  442. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))
  443. #define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...) \
  444. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  445. at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
  446. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  447. at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
  448. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  449. at::kQInt32, \
  450. at::qint32, \
  451. CHAR_BIT * sizeof(int), \
  452. INT_MIN, \
  453. INT_MAX, \
  454. __VA_ARGS__) \
  455. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  456. at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__) \
  457. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  458. at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)
  459. #define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
  460. AT_DISPATCH_SWITCH( \
  461. TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
  462. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
  463. AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
  464. AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
  465. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
  466. AT_DISPATCH_SWITCH( \
  467. TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))
  468. #define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
  469. AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
  470. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  471. #define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  472. AT_DISPATCH_SWITCH( \
  473. TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
  474. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
  475. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  476. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  477. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
  478. AT_DISPATCH_SWITCH( \
  479. TYPE, \
  480. NAME, \
  481. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))
  482. #define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  483. AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
  484. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  485. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
  486. #define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  487. AT_DISPATCH_SWITCH( \
  488. TYPE, \
  489. NAME, \
  490. AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
  491. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
  492. SCALARTYPE1, SCALARTYPE2, ...) \
  493. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  494. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  495. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
  496. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
  497. SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  498. AT_DISPATCH_SWITCH( \
  499. TYPE, \
  500. NAME, \
  501. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
  502. SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
  503. #define AT_DISPATCH_CASE_ALL_TYPES_AND3( \
  504. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
  505. AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
  506. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  507. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  508. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
  509. #define AT_DISPATCH_ALL_TYPES_AND3( \
  510. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  511. AT_DISPATCH_SWITCH( \
  512. TYPE, \
  513. NAME, \
  514. AT_DISPATCH_CASE_ALL_TYPES_AND3( \
  515. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
  516. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
  517. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
  518. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  519. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  520. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  521. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
  522. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
  523. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  524. AT_DISPATCH_SWITCH( \
  525. TYPE, \
  526. NAME, \
  527. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
  528. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
  529. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
  530. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  531. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  532. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  533. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  534. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  535. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
  536. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
  537. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  538. AT_DISPATCH_SWITCH( \
  539. TYPE, \
  540. NAME, \
  541. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
  542. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
  543. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
  544. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
  545. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  546. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  547. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  548. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  549. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  550. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
  551. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
  552. SCALARTYPE1, \
  553. SCALARTYPE2, \
  554. SCALARTYPE3, \
  555. SCALARTYPE4, \
  556. SCALARTYPE5, \
  557. TYPE, \
  558. NAME, \
  559. ...) \
  560. AT_DISPATCH_SWITCH( \
  561. TYPE, \
  562. NAME, \
  563. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
  564. SCALARTYPE1, \
  565. SCALARTYPE2, \
  566. SCALARTYPE3, \
  567. SCALARTYPE4, \
  568. SCALARTYPE5, \
  569. __VA_ARGS__))
  570. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
  571. SCALARTYPE1, \
  572. SCALARTYPE2, \
  573. SCALARTYPE3, \
  574. SCALARTYPE4, \
  575. SCALARTYPE5, \
  576. SCALARTYPE6, \
  577. ...) \
  578. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  579. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  580. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  581. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  582. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  583. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
  584. AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
  585. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
  586. SCALARTYPE1, \
  587. SCALARTYPE2, \
  588. SCALARTYPE3, \
  589. SCALARTYPE4, \
  590. SCALARTYPE5, \
  591. SCALARTYPE6, \
  592. TYPE, \
  593. NAME, \
  594. ...) \
  595. AT_DISPATCH_SWITCH( \
  596. TYPE, \
  597. NAME, \
  598. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
  599. SCALARTYPE1, \
  600. SCALARTYPE2, \
  601. SCALARTYPE3, \
  602. SCALARTYPE4, \
  603. SCALARTYPE5, \
  604. SCALARTYPE6, \
  605. __VA_ARGS__))
  606. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
  607. SCALARTYPE1, \
  608. SCALARTYPE2, \
  609. SCALARTYPE3, \
  610. SCALARTYPE4, \
  611. SCALARTYPE5, \
  612. SCALARTYPE6, \
  613. SCALARTYPE7, \
  614. ...) \
  615. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  616. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  617. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  618. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  619. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  620. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
  621. AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
  622. AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__)
  623. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7( \
  624. SCALARTYPE1, \
  625. SCALARTYPE2, \
  626. SCALARTYPE3, \
  627. SCALARTYPE4, \
  628. SCALARTYPE5, \
  629. SCALARTYPE6, \
  630. SCALARTYPE7, \
  631. TYPE, \
  632. NAME, \
  633. ...) \
  634. AT_DISPATCH_SWITCH( \
  635. TYPE, \
  636. NAME, \
  637. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
  638. SCALARTYPE1, \
  639. SCALARTYPE2, \
  640. SCALARTYPE3, \
  641. SCALARTYPE4, \
  642. SCALARTYPE5, \
  643. SCALARTYPE6, \
  644. SCALARTYPE7, \
  645. __VA_ARGS__))
  646. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
  647. SCALARTYPE1, \
  648. SCALARTYPE2, \
  649. SCALARTYPE3, \
  650. SCALARTYPE4, \
  651. SCALARTYPE5, \
  652. SCALARTYPE6, \
  653. SCALARTYPE7, \
  654. SCALARTYPE8, \
  655. ...) \
  656. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  657. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  658. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  659. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  660. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  661. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
  662. AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
  663. AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) \
  664. AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__)
  665. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
  666. SCALARTYPE1, \
  667. SCALARTYPE2, \
  668. SCALARTYPE3, \
  669. SCALARTYPE4, \
  670. SCALARTYPE5, \
  671. SCALARTYPE6, \
  672. SCALARTYPE7, \
  673. SCALARTYPE8, \
  674. TYPE, \
  675. NAME, \
  676. ...) \
  677. AT_DISPATCH_SWITCH( \
  678. TYPE, \
  679. NAME, \
  680. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
  681. SCALARTYPE1, \
  682. SCALARTYPE2, \
  683. SCALARTYPE3, \
  684. SCALARTYPE4, \
  685. SCALARTYPE5, \
  686. SCALARTYPE6, \
  687. SCALARTYPE7, \
  688. SCALARTYPE8, \
  689. __VA_ARGS__))
  690. #define AT_DISPATCH_CASE_BIT_TYPES(...) \
  691. AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \
  692. AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \
  693. AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \
  694. AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__) \
  695. AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__)
  696. #define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \
  697. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__))
  698. #define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
  699. AT_DISPATCH_SWITCH( \
  700. TYPE, \
  701. NAME, \
  702. AT_PRIVATE_CASE_TYPE_USING_HINT( \
  703. at::ScalarType::Int, index_t, __VA_ARGS__) \
  704. AT_PRIVATE_CASE_TYPE_USING_HINT( \
  705. at::ScalarType::Long, index_t, __VA_ARGS__))
  706. // ----------------------------------------------------------------------------
  707. // DEPRECATED MACROS, DON'T USE THESE
  708. // ----------------------------------------------------------------------------
  709. #define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
  710. detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
  711. AT_DISPATCH_SWITCH( \
  712. TYPE, \
  713. NAME, \
  714. AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__))