IndexKernels.h

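// Metal shader source for the MPS indexing kernels (index_select, index_put and its
// accumulating variants, plus the scatter/gather copy kernels), embedded as raw strings
// so the host can compile them at runtime.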
#pragma once
namespace at::mps {
static const char * indexing_metal_shaders = R"INDEX_METAL(
#include <metal_stdlib>
#include <metal_atomic>
using namespace metal;
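
// IndexAB carries the list of index tensors. With Metal 3 argument buffers the kernels
// receive an array of IndexAB entries, each holding a pointer to one index buffer; on
// older Metal versions a single IndexAB instead packs up to 16 buffer pointers into a
// metal::array.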
#if __METAL_VERSION__ < 300
struct IndexAB {
  // Allow up to 16 indices
  metal::array<constant void *, 16> indexArray [[ id(0) ]];
};
#else
struct IndexAB {
  constant int64_t* indexArray;
};
#endif
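
// index_select: one thread per output element. offsets[thread_index] packs three byte
// offsets: .x into the output, .y into the input and .z into the index buffers. The loop
// adds the contribution of each index tensor to the input offset, wrapping negative
// indices by the corresponding dimension size.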
template<typename T, typename OffsetsT>
kernel void index_select(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t * index_sizes = (constant int64_t *)indexSizes;
  constant int64_t * index_strides = (constant int64_t *)indexStrides;
  int64_t offset = 0;
  for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
    constant int64_t* indexArray = indexAB[i].indexArray;
#else
    constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
    int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
    if (index < 0) {
      index += index_sizes[i];
    }
    offset += index * index_strides[i];
  }
  device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
  constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset);
  *out = *in;
}
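
// index_put_impl: shared helper for the index_put kernels. Same offset computation as
// index_select, but the accumulated index offset is applied to the output pointer, so the
// value is scattered rather than gathered.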
template<typename T, typename OffsetsT>
void index_put_impl(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB,
#else
    constant IndexAB & indexAB,
#endif
    constant int64_t * index_sizes,
    constant int64_t * index_strides,
    constant OffsetsT * offsets,
    constant void * inputData,
    device void * outputData,
    constant uint32_t & num_indices,
    uint thread_index) {
  int64_t offset = 0;
  for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
    constant int64_t* indexArray = indexAB[i].indexArray;
#else
    constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
    int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
    if (index < 0) {
      index += index_sizes[i];
    }
    offset += index * index_strides[i];
  }
  device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
  constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
  *out = *in;
}
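
// index_put_serial: meant to be launched with a single thread; it ignores thread_index
// and iterates over all *numIters elements itself, so writes to duplicate indices are
// applied in one well-defined order.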
template<typename T, typename OffsetsT>
kernel void index_put_serial(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    constant uint * numIters [[buffer(7)]],
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t * index_sizes = (constant int64_t *)indexSizes;
  constant int64_t * index_strides = (constant int64_t *)indexStrides;
  for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
    index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
  }
}
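
// index_put: fully parallel variant, one element per thread.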
template<typename T, typename OffsetsT>
kernel void index_put(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t * index_sizes = (constant int64_t *)indexSizes;
  constant int64_t * index_strides = (constant int64_t *)indexStrides;
  index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
}
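
// Explicit instantiations. host_name fixes the symbol the host looks up, e.g.
// REGISTER_INDEX_OP(32bit, idx32, int, select, uint3) is exposed as
// "index_select_32bit_idx32". Element types are keyed by their width, presumably so that
// same-sized dtypes can share one instantiation.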
#if __METAL_VERSION__ < 300
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
    constant IndexAB & indexAB [[buffer(0)]], \
    constant void * indexSizes [[buffer(1)]], \
    constant void * indexStrides [[buffer(2)]], \
    constant IDX_DTYPE * offsets [[buffer(3)]], \
    constant void * inputData [[buffer(4)]], \
    device void * outputData [[buffer(5)]], \
    constant uint32_t & num_indices [[buffer(6)]], \
    uint thread_index [[thread_position_in_grid]]);
#else
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
    constant IndexAB * indexAB [[buffer(0)]], \
    constant void * indexSizes [[buffer(1)]], \
    constant void * indexStrides [[buffer(2)]], \
    constant IDX_DTYPE * offsets [[buffer(3)]], \
    constant void * inputData [[buffer(4)]], \
    device void * outputData [[buffer(5)]], \
    constant uint32_t & num_indices [[buffer(6)]], \
    uint thread_index [[thread_position_in_grid]]);
#endif
#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
    REGISTER_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);

REGISTER_INDEX_OP_ALL_DTYPES(select);
REGISTER_INDEX_OP_ALL_DTYPES(put);
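
// The serial put kernel is registered the same way, but with the extra numIters argument
// in buffer slot 7.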
#if __METAL_VERSION__ < 300
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
    constant IndexAB & indexAB [[buffer(0)]], \
    constant void * indexSizes [[buffer(1)]], \
    constant void * indexStrides [[buffer(2)]], \
    constant IDX_DTYPE * offsets [[buffer(3)]], \
    constant void * inputData [[buffer(4)]], \
    device void * outputData [[buffer(5)]], \
    constant uint32_t & num_indices [[buffer(6)]], \
    constant uint * numIters [[buffer(7)]], \
    uint thread_index [[thread_position_in_grid]]);
#else
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
    constant IndexAB * indexAB [[buffer(0)]], \
    constant void * indexSizes [[buffer(1)]], \
    constant void * indexStrides [[buffer(2)]], \
    constant IDX_DTYPE * offsets [[buffer(3)]], \
    constant void * inputData [[buffer(4)]], \
    device void * outputData [[buffer(5)]], \
    constant uint32_t & num_indices [[buffer(6)]], \
    constant uint * numIters [[buffer(7)]], \
    uint thread_index [[thread_position_in_grid]]);
#endif
#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
    REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);

REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
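
// kernel_index_offsets: turns a flat thread index into a per-thread offset vector by
// peeling off one dimension of iter_shape at a time (innermost first) and scaling each
// remainder by that dimension's strides. For example, with iter_shape = {3, 4},
// strides = {{1,1,1}, {3,3,3}} and thread_index 5, the remainders are (2, 1) and every
// lane of the offset becomes 2*1 + 1*3 = 5.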
template<typename StridesT, typename DataT>
kernel void kernel_index_offsets(
    constant StridesT * strides [[buffer(0)]],
    device DataT * data_offsets [[buffer(1)]],
    constant uint * iter_shape [[buffer(2)]],
    constant uint & num_dimensions [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]) {
  data_offsets[thread_index] = 0;
  uint32_t idx = thread_index;
  for (uint32_t dim = 0; dim < num_dimensions; dim++) {
    uint32_t remainder = idx % iter_shape[dim];
    idx /= iter_shape[dim];
    data_offsets[thread_index] += remainder * DataT(strides[dim]);
  }
}
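
// Instantiated for 32-bit (uint3) and 64-bit (ulong3) offset vectors.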
template
[[host_name("kernel_index_offsets_32")]]
kernel void kernel_index_offsets<packed_uint3, uint3>(
    constant packed_uint3 * strides [[buffer(0)]],
    device uint3 * data_offsets [[buffer(1)]],
    constant uint * iter_shape [[buffer(2)]],
    constant uint & num_dimensions [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("kernel_index_offsets_64")]]
kernel void kernel_index_offsets<packed_uint3, ulong3>(
    constant packed_uint3 * strides [[buffer(0)]],
    device ulong3 * data_offsets [[buffer(1)]],
    constant uint * iter_shape [[buffer(2)]],
    constant uint & num_dimensions [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]);
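
// index_put with accumulation for element types that have a native Metal atomic
// counterpart (T is the atomic type, E the plain element type): the add is a single
// relaxed atomic_fetch_add.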
template<typename T, typename E, typename OffsetsT>
kernel void index_put_accumulate_native_dtypes(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t * index_sizes = (constant int64_t *)indexSizes;
  constant int64_t * index_strides = (constant int64_t *)indexStrides;
  int64_t offset = 0;
  for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
    constant int64_t* indexArray = indexAB[i].indexArray;
#else
    constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
    int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
    if (index < 0) {
      index += index_sizes[i];
    }
    offset += index * index_strides[i];
  }
  device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
  constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y);
  atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
}
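
// atomic_fetch_add_relaxed: emulates a relaxed atomic add for types without a hardware
// fetch_add (used below for float) by bit-casting through uint and retrying a weak
// compare-exchange. On failure the CAS stores the freshly observed value into `expected`,
// so the addition is recomputed against up-to-date data on each iteration.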
template<typename T>
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
  device atomic_uint* uintAddr = (device atomic_uint*)addr;
  uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
  T updated = as_type<T>(expected) + value;
  while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
    updated = as_type<T>(expected) + value;
  }
}
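
// Accumulating index_put for float: identical offset computation, with the final write
// routed through the CAS-based atomic add above.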
template<typename T, typename OffsetsT>
kernel void atomic_index_put_accumulate(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t * index_sizes = (constant int64_t *)indexSizes;
  constant int64_t * index_strides = (constant int64_t *)indexStrides;
  int64_t offset = 0;
  for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
    constant int64_t* indexArray = indexAB[i].indexArray;
#else
    constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
    int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
    if (index < 0) {
      index += index_sizes[i];
    }
    offset += index * index_strides[i];
  }
  device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
  constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
  atomic_fetch_add_relaxed<T>(out, *in);
}
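
// Accumulation is only registered for 32-bit types: float goes through the CAS emulation,
// int through native atomic_int, each with 32-bit and 64-bit offset variants.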
template
[[host_name("index_put_accumulate_32bit_float_idx32")]]
kernel void atomic_index_put_accumulate<float, uint3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant uint3 * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_float_idx64")]]
kernel void atomic_index_put_accumulate<float, ulong3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant ulong3 * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_int_idx32")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant uint3 * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_int_idx64")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant ulong3 * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);
)INDEX_METAL";
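
// SCATTER_OPS_TEMPLATE and GATHER_OPS_TEMPLATE below are format strings, not finished
// Metal source: "{0}" is the source element type, "{1}" the destination element type and
// "{2}" the cast expression, and every literal brace is doubled so it survives
// formatting. A minimal sketch of the substitution, assuming a fmt-style formatter on the
// host (the real call site lives outside this header, and the types used here are only an
// illustration):
//
//   std::string src = fmt::format(SCATTER_OPS_TEMPLATE,
//                                 "float",                  // {0}: src dtype
//                                 "half",                   // {1}: dst dtype
//                                 "static_cast<half>(x)");  // {2}: cast expression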
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};

template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
  return {2};
}}

kernel void scatter_kernel_5(
    uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint5 & size [[buffer(2)]],
    constant packed_uint5 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;
  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;
  packed_uint5 local_index;
  local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
  local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
  local_index.z = linear_index / (size.u * size.w) % size.z;
  local_index.w = linear_index / size.u % size.w;
  local_index.u = linear_index % size.u;
  packed_uint5 strided_index;
  strided_index.x = local_index.x * stride.x;
  strided_index.y = local_index.y * stride.y;
  strided_index.z = local_index.z * stride.z;
  strided_index.w = local_index.w * stride.w;
  strided_index.u = local_index.u * stride.u;
  dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_4(
    uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint4 & size [[buffer(2)]],
    constant packed_uint4 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;
  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;
  packed_uint4 local_index;
  local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
  local_index.y = linear_index / (size[3] * size[2]) % size[1];
  local_index.z = linear_index / size[3] % size[2];
  local_index.w = linear_index % size[3];
  const packed_uint4 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_3(
    uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint3 & size [[buffer(2)]],
    constant packed_uint3 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;
  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;
  packed_uint3 local_index;
  local_index.x = linear_index / (size[2] * size[1]) % size[0];
  local_index.y = linear_index / size[2] % size[1];
  local_index.z = linear_index % size[2];
  const packed_uint3 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_2(
    uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint2 & size [[buffer(2)]],
    constant packed_uint2 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;
  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;
  packed_uint2 local_index;
  local_index.x = linear_index / size[1] % size[0];
  local_index.y = linear_index % size[1];
  const packed_uint2 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_1(
    uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant int & size [[buffer(2)]],
    constant int & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;
  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;
  const int local_index = linear_index % size;
  const int strided_index = local_index * stride;
  dst[strided_index] = cast<{1}>(src[linear_index]);
}}
)METAL_SCATTER";
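
// The gather kernels mirror the scatter kernels above with the roles swapped: the strided
// (multi-dimensional) index addresses the source and the linear index addresses the
// destination, so a strided view is packed into a contiguous buffer instead of unpacked
// from one.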
static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};

template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
  return {2};
}}

kernel void gather_kernel_5(
    uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint5 & size [[buffer(2)]],
    constant packed_uint5 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;
  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;
  packed_uint5 local_index;
  local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
  local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
  local_index.z = linear_index / (size.u * size.w) % size.z;
  local_index.w = linear_index / size.u % size.w;
  local_index.u = linear_index % size.u;
  packed_uint5 strided_index;
  strided_index.x = local_index.x * stride.x;
  strided_index.y = local_index.y * stride.y;
  strided_index.z = local_index.z * stride.z;
  strided_index.w = local_index.w * stride.w;
  strided_index.u = local_index.u * stride.u;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
}}

kernel void gather_kernel_4(
    uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint4 & size [[buffer(2)]],
    constant packed_uint4 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;
  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;
  packed_uint4 local_index;
  local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
  local_index.y = linear_index / (size[3] * size[2]) % size[1];
  local_index.z = linear_index / size[3] % size[2];
  local_index.w = linear_index % size[3];
  const packed_uint4 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
}}

kernel void gather_kernel_3(
    uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint3 & size [[buffer(2)]],
    constant packed_uint3 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;
  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;
  packed_uint3 local_index;
  local_index.x = linear_index / (size[2] * size[1]) % size[0];
  local_index.y = linear_index / size[2] % size[1];
  local_index.z = linear_index % size[2];
  const packed_uint3 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
}}

kernel void gather_kernel_2(
    uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant packed_uint2 & size [[buffer(2)]],
    constant packed_uint2 & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;
  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;
  packed_uint2 local_index;
  local_index.x = linear_index / size[1] % size[0];
  local_index.y = linear_index % size[1];
  const packed_uint2 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
}}

kernel void gather_kernel_1(
    uint linear_index [[thread_position_in_grid]],
    constant void * src_ [[buffer(0)]],
    device void * dst_ [[buffer(1)]],
    constant int & size [[buffer(2)]],
    constant int & stride [[buffer(3)]],
    constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;
  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;
  const int local_index = linear_index % size;
  const int strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index]);
}}
)METAL_GATHER";
} // namespace at::mps