#pragma once

namespace at::mps {

static const char * indexing_metal_shaders = R"INDEX_METAL(
#include <metal_stdlib>
#include <metal_atomic>

using namespace metal;

#if __METAL_VERSION__ < 300
struct IndexAB {
  // Allow up to 16 indices
  metal::array<constant void *, 16> indexArray [[ id(0) ]];
};
#else
struct IndexAB {
  constant int64_t* indexArray;
};
#endif
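// IndexAB carries one index-tensor pointer per advanced index. Metal 3 and
// newer let the host bind an array of buffer pointers directly, so the struct
// is a single pointer and the kernels below take `constant IndexAB *`.
// Earlier Metal versions route the pointers through an argument buffer with a
// fixed capacity of 16 entries, hence the [[ id(0) ]] slot and the
// `constant IndexAB &` parameters in the fallback branches.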
template<typename T, typename OffsetsT>
kernel void index_select(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB      [[buffer(0)]],
#else
    constant IndexAB & indexAB      [[buffer(0)]],
#endif
    constant void * indexSizes      [[buffer(1)]],
    constant void * indexStrides    [[buffer(2)]],
    constant OffsetsT * offsets     [[buffer(3)]],
    constant void * inputData       [[buffer(4)]],
    device void * outputData        [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t * index_sizes = (constant int64_t *)indexSizes;
  constant int64_t * index_strides = (constant int64_t *)indexStrides;
  int64_t offset = 0;
  for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
    constant int64_t* indexArray = indexAB[i].indexArray;
#else
    constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
    int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
    if (index < 0) {
      index += index_sizes[i];
    }
    offset += index * index_strides[i];
  }
  device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
  constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset);
  *out = *in;
}
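// Each entry of `offsets` packs three byte offsets for one output element:
// .x into the output tensor, .y into the input tensor, and .z into the index
// tensors. The same .z offset addresses every index buffer, which is why it
// is scaled by sizeof(int64_t) once per element. These triples are presumably
// produced by kernel_index_offsets, defined further down in this file.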
template<typename T, typename OffsetsT>
void index_put_impl(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB,
#else
    constant IndexAB & indexAB,
#endif
    constant int64_t * index_sizes,
    constant int64_t * index_strides,
    constant OffsetsT * offsets,
    constant void * inputData,
    device void * outputData,
    constant uint32_t & num_indices,
    uint thread_index) {
  int64_t offset = 0;
  for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
    constant int64_t* indexArray = indexAB[i].indexArray;
#else
    constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
    int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
    if (index < 0) {
      index += index_sizes[i];
    }
    offset += index * index_strides[i];
  }
  device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
  constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
  *out = *in;
}
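// index_put_impl is the mirror image of index_select: the accumulated index
// offset is applied to the *output* address instead of the input address, so
// the gather above becomes a scatter here.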
template<typename T, typename OffsetsT>
kernel void index_put_serial(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB      [[buffer(0)]],
#else
    constant IndexAB & indexAB      [[buffer(0)]],
#endif
    constant void * indexSizes      [[buffer(1)]],
    constant void * indexStrides    [[buffer(2)]],
    constant OffsetsT * offsets     [[buffer(3)]],
    constant void * inputData       [[buffer(4)]],
    device void * outputData        [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    constant uint * numIters        [[buffer(7)]],
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t * index_sizes = (constant int64_t *)indexSizes;
  constant int64_t * index_strides = (constant int64_t *)indexStrides;
  for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
    index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
  }
}
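// Serial variant: the host presumably dispatches this kernel with a
// single-thread grid, and the loop walks all numIters elements in order, so
// writes to duplicate indices land in a deterministic order. thread_index is
// intentionally unused.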
template<typename T, typename OffsetsT>
kernel void index_put(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB      [[buffer(0)]],
#else
    constant IndexAB & indexAB      [[buffer(0)]],
#endif
    constant void * indexSizes      [[buffer(1)]],
    constant void * indexStrides    [[buffer(2)]],
    constant OffsetsT * offsets     [[buffer(3)]],
    constant void * inputData       [[buffer(4)]],
    device void * outputData        [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t * index_sizes = (constant int64_t *)indexSizes;
  constant int64_t * index_strides = (constant int64_t *)indexStrides;
  index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
}
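// Parallel variant: one thread per element. Threads handed duplicate indices
// would race on the same output location, so this path is only safe when that
// cannot happen; the serial kernel above and the accumulate kernels below
// cover the overlapping cases.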
#if __METAL_VERSION__ < 300
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template                                                                         \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]             \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                           \
    constant IndexAB & indexAB      [[buffer(0)]],                               \
    constant void * indexSizes      [[buffer(1)]],                               \
    constant void * indexStrides    [[buffer(2)]],                               \
    constant IDX_DTYPE * offsets    [[buffer(3)]],                               \
    constant void * inputData       [[buffer(4)]],                               \
    device void * outputData        [[buffer(5)]],                               \
    constant uint32_t & num_indices [[buffer(6)]],                               \
    uint thread_index [[thread_position_in_grid]]);
#else
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template                                                                         \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]             \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                           \
    constant IndexAB * indexAB      [[buffer(0)]],                               \
    constant void * indexSizes      [[buffer(1)]],                               \
    constant void * indexStrides    [[buffer(2)]],                               \
    constant IDX_DTYPE * offsets    [[buffer(3)]],                               \
    constant void * inputData       [[buffer(4)]],                               \
    device void * outputData        [[buffer(5)]],                               \
    constant uint32_t & num_indices [[buffer(6)]],                               \
    uint thread_index [[thread_position_in_grid]]);
#endif
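// REGISTER_INDEX_OP explicitly instantiates one kernel template and pins its
// host-visible name via [[host_name]]. For example,
//   REGISTER_INDEX_OP(32bit, idx32, int, select, uint3);
// instantiates index_select<int, uint3> and exposes it to the host as the
// library function "index_select_32bit_idx32".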
#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE)              \
  REGISTER_INDEX_OP(8bit,  idx32, char,  INDEX_OP_TYPE, uint3);  \
  REGISTER_INDEX_OP(8bit,  idx64, char,  INDEX_OP_TYPE, ulong3); \
  REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3);  \
  REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
  REGISTER_INDEX_OP(32bit, idx32, int,   INDEX_OP_TYPE, uint3);  \
  REGISTER_INDEX_OP(32bit, idx64, int,   INDEX_OP_TYPE, ulong3); \
  REGISTER_INDEX_OP(64bit, idx32, long,  INDEX_OP_TYPE, uint3);  \
  REGISTER_INDEX_OP(64bit, idx64, long,  INDEX_OP_TYPE, ulong3);

REGISTER_INDEX_OP_ALL_DTYPES(select);
REGISTER_INDEX_OP_ALL_DTYPES(put);
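// Kernels are registered by element *size* rather than element type: the copy
// performed by index_select/index_put is bitwise, so e.g. float values can
// presumably be moved through the 32bit (int) instantiations.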
#if __METAL_VERSION__ < 300
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template                                                                                         \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]                             \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                                           \
    constant IndexAB & indexAB      [[buffer(0)]],                                               \
    constant void * indexSizes      [[buffer(1)]],                                               \
    constant void * indexStrides    [[buffer(2)]],                                               \
    constant IDX_DTYPE * offsets    [[buffer(3)]],                                               \
    constant void * inputData       [[buffer(4)]],                                               \
    device void * outputData        [[buffer(5)]],                                               \
    constant uint32_t & num_indices [[buffer(6)]],                                               \
    constant uint * numIters        [[buffer(7)]],                                               \
    uint thread_index [[thread_position_in_grid]]);
#else
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template                                                                                         \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]                             \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                                           \
    constant IndexAB * indexAB      [[buffer(0)]],                                               \
    constant void * indexSizes      [[buffer(1)]],                                               \
    constant void * indexStrides    [[buffer(2)]],                                               \
    constant IDX_DTYPE * offsets    [[buffer(3)]],                                               \
    constant void * inputData       [[buffer(4)]],                                               \
    device void * outputData        [[buffer(5)]],                                               \
    constant uint32_t & num_indices [[buffer(6)]],                                               \
    constant uint * numIters        [[buffer(7)]],                                               \
    uint thread_index [[thread_position_in_grid]]);
#endif

#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE)              \
  REGISTER_SINGLE_THREADED_INDEX_OP(8bit,  idx32, char,  INDEX_OP_TYPE, uint3);  \
  REGISTER_SINGLE_THREADED_INDEX_OP(8bit,  idx64, char,  INDEX_OP_TYPE, ulong3); \
  REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3);  \
  REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
  REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int,   INDEX_OP_TYPE, uint3);  \
  REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int,   INDEX_OP_TYPE, ulong3); \
  REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long,  INDEX_OP_TYPE, uint3);  \
  REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long,  INDEX_OP_TYPE, ulong3);

REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
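// Same registration scheme as REGISTER_INDEX_OP above; the only difference is
// the extra numIters argument in buffer(7) consumed by the serial kernel.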
template<typename StridesT, typename DataT>
kernel void kernel_index_offsets(constant StridesT * strides    [[buffer(0)]],
                                 device DataT * data_offsets    [[buffer(1)]],
                                 constant uint * iter_shape     [[buffer(2)]],
                                 constant uint & num_dimensions [[buffer(3)]],
                                 uint thread_index [[thread_position_in_grid]]) {
  data_offsets[thread_index] = 0;
  uint32_t idx = thread_index;
  for (uint32_t dim = 0; dim < num_dimensions; dim++) {
    uint32_t remainder = idx % iter_shape[dim];
    idx /= iter_shape[dim];
    data_offsets[thread_index] += remainder * DataT(strides[dim]);
  }
}
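// Standard unravelling of a flat thread index into per-dimension coordinates,
// accumulating coordinate * stride per dimension. Each strides[dim] is a
// packed_uint3 holding the output/input/index byte strides, so one pass fills
// the (.x, .y, .z) offset triple consumed by the kernels above. For example,
// with iter_shape = (4, 3) and thread_index = 7: dim 0 gives remainder 3,
// dim 1 gives remainder 1, so data_offsets[7] = 3 * strides[0] + 1 * strides[1].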
template
[[host_name("kernel_index_offsets_32")]]
kernel void kernel_index_offsets<packed_uint3, uint3>(
    constant packed_uint3 * strides [[buffer(0)]],
    device uint3 * data_offsets     [[buffer(1)]],
    constant uint * iter_shape      [[buffer(2)]],
    constant uint & num_dimensions  [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("kernel_index_offsets_64")]]
kernel void kernel_index_offsets<packed_uint3, ulong3>(
    constant packed_uint3 * strides [[buffer(0)]],
    device ulong3 * data_offsets    [[buffer(1)]],
    constant uint * iter_shape      [[buffer(2)]],
    constant uint & num_dimensions  [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]);
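// The 32-bit variant keeps offsets in uint3; the 64-bit variant widens them to
// ulong3, presumably for tensors whose byte offsets overflow 32 bits.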
template<typename T, typename E, typename OffsetsT>
kernel void index_put_accumulate_native_dtypes(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB      [[buffer(0)]],
#else
    constant IndexAB & indexAB      [[buffer(0)]],
#endif
    constant void * indexSizes      [[buffer(1)]],
    constant void * indexStrides    [[buffer(2)]],
    constant OffsetsT * offsets     [[buffer(3)]],
    constant void * inputData       [[buffer(4)]],
    device void * outputData        [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t * index_sizes = (constant int64_t *)indexSizes;
  constant int64_t * index_strides = (constant int64_t *)indexStrides;
  int64_t offset = 0;
  for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
    constant int64_t* indexArray = indexAB[i].indexArray;
#else
    constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
    int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
    if (index < 0) {
      index += index_sizes[i];
    }
    offset += index * index_strides[i];
  }
  device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
  constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y);
  atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
}
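// Accumulate path for dtypes with a native Metal atomic type: T is the atomic
// type (e.g. atomic_int) and E the element type, so the addition can go
// through atomic_fetch_add_explicit directly.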
template<typename T>
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
  device atomic_uint* uintAddr = (device atomic_uint*)addr;
  uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
  T updated = as_type<T>(expected) + value;
  while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
    updated = as_type<T>(expected) + value;
  }
}
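// There is no atomic float add available on this path, so the 4-byte value is
// aliased as an atomic_uint and updated with a compare-exchange loop: when the
// exchange fails, `expected` has been refreshed with the current value and the
// addition is retried. Note the bit-cast trick only works for 4-byte types
// such as float.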
template<typename T, typename OffsetsT>
kernel void atomic_index_put_accumulate(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB      [[buffer(0)]],
#else
    constant IndexAB & indexAB      [[buffer(0)]],
#endif
    constant void * indexSizes      [[buffer(1)]],
    constant void * indexStrides    [[buffer(2)]],
    constant OffsetsT * offsets     [[buffer(3)]],
    constant void * inputData       [[buffer(4)]],
    device void * outputData        [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t * index_sizes = (constant int64_t *)indexSizes;
  constant int64_t * index_strides = (constant int64_t *)indexStrides;
  int64_t offset = 0;
  for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
    constant int64_t* indexArray = indexAB[i].indexArray;
#else
    constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
    int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
    if (index < 0) {
      index += index_sizes[i];
    }
    offset += index * index_strides[i];
  }
  device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
  constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
  atomic_fetch_add_relaxed<T>(out, *in);
}
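// Same indexing logic as index_put_accumulate_native_dtypes, but the
// accumulation goes through the CAS-based atomic_fetch_add_relaxed above; the
// instantiations below use it for float.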
template
[[host_name("index_put_accumulate_32bit_float_idx32")]]
kernel void atomic_index_put_accumulate<float, uint3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB      [[buffer(0)]],
#else
    constant IndexAB & indexAB      [[buffer(0)]],
#endif
    constant void * indexSizes      [[buffer(1)]],
    constant void * indexStrides    [[buffer(2)]],
    constant uint3 * offsets        [[buffer(3)]],
    constant void * inputData       [[buffer(4)]],
    device void * outputData        [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_float_idx64")]]
kernel void atomic_index_put_accumulate<float, ulong3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB      [[buffer(0)]],
#else
    constant IndexAB & indexAB      [[buffer(0)]],
#endif
    constant void * indexSizes      [[buffer(1)]],
    constant void * indexStrides    [[buffer(2)]],
    constant ulong3 * offsets       [[buffer(3)]],
    constant void * inputData       [[buffer(4)]],
    device void * outputData        [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_int_idx32")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB      [[buffer(0)]],
#else
    constant IndexAB & indexAB      [[buffer(0)]],
#endif
    constant void * indexSizes      [[buffer(1)]],
    constant void * indexStrides    [[buffer(2)]],
    constant uint3 * offsets        [[buffer(3)]],
    constant void * inputData       [[buffer(4)]],
    device void * outputData        [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_int_idx64")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB      [[buffer(0)]],
#else
    constant IndexAB & indexAB      [[buffer(0)]],
#endif
    constant void * indexSizes      [[buffer(1)]],
    constant void * indexStrides    [[buffer(2)]],
    constant ulong3 * offsets       [[buffer(3)]],
    constant void * inputData       [[buffer(4)]],
    device void * outputData        [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);
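// Only the 32-bit accumulate kernels are instantiated: float through the CAS
// helper, int through the native atomic_int path.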
)INDEX_METAL";
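// The two strings below are not complete Metal sources but runtime templates:
// {0} is the source element type, {1} the destination element type, and {2} a
// cast expression, all substituted by the host before compilation (the doubled
// braces {{ }} suggest fmt/Python-style formatting, which renders them as
// literal braces). packed_uint5 is hand-rolled because Metal has no built-in
// five-component vector type.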
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};

template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
  return {2};
}}

kernel void scatter_kernel_5(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint5 & size   [[buffer(2)]],
                             constant packed_uint5 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint5 local_index;
  local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
  local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
  local_index.z = linear_index / (size.u * size.w) % size.z;
  local_index.w = linear_index / size.u % size.w;
  local_index.u = linear_index % size.u;

  packed_uint5 strided_index;
  strided_index.x = local_index.x * stride.x;
  strided_index.y = local_index.y * stride.y;
  strided_index.z = local_index.z * stride.z;
  strided_index.w = local_index.w * stride.w;
  strided_index.u = local_index.u * stride.u;

  dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
}}
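// scatter_kernel_N copies element linear_index of a contiguous source into a
// strided destination: the flat index is unravelled against `size` and
// re-weighted by `stride`, expressed in elements of the destination type. The
// 5-D kernel spells the multiply out field by field because the hand-rolled
// packed_uint5 has no operator*; the lower-rank kernels below use the built-in
// packed vector types and a single vector multiply.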
kernel void scatter_kernel_4(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint4 & size   [[buffer(2)]],
                             constant packed_uint4 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint4 local_index;
  local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
  local_index.y = linear_index / (size[3] * size[2]) % size[1];
  local_index.z = linear_index / size[3] % size[2];
  local_index.w = linear_index % size[3];

  const packed_uint4 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_3(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint3 & size   [[buffer(2)]],
                             constant packed_uint3 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint3 local_index;
  local_index.x = linear_index / (size[2] * size[1]) % size[0];
  local_index.y = linear_index / size[2] % size[1];
  local_index.z = linear_index % size[2];

  const packed_uint3 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_2(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint2 & size   [[buffer(2)]],
                             constant packed_uint2 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint2 local_index;
  local_index.x = linear_index / size[1] % size[0];
  local_index.y = linear_index % size[1];

  const packed_uint2 strided_index = local_index * stride;
  dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_1(uint linear_index     [[thread_position_in_grid]],
                             constant void * src_  [[buffer(0)]],
                             device void * dst_    [[buffer(1)]],
                             constant int & size   [[buffer(2)]],
                             constant int & stride [[buffer(3)]],
                             constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  const int local_index = linear_index % size;
  const int strided_index = local_index * stride;
  dst[strided_index] = cast<{1}>(src[linear_index]);
}}
)METAL_SCATTER";
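// Scatter writes a contiguous buffer out to a strided view; gather below is
// the exact inverse, reading a strided source into a contiguous destination
// (dst[linear_index] = src[strided] rather than dst[strided] = src[linear_index]).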
static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};

template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
  return {2};
}}

kernel void gather_kernel_5(uint linear_index              [[thread_position_in_grid]],
                            constant void * src_           [[buffer(0)]],
                            device void * dst_             [[buffer(1)]],
                            constant packed_uint5 & size   [[buffer(2)]],
                            constant packed_uint5 & stride [[buffer(3)]],
                            constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint5 local_index;
  local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
  local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
  local_index.z = linear_index / (size.u * size.w) % size.z;
  local_index.w = linear_index / size.u % size.w;
  local_index.u = linear_index % size.u;

  packed_uint5 strided_index;
  strided_index.x = local_index.x * stride.x;
  strided_index.y = local_index.y * stride.y;
  strided_index.z = local_index.z * stride.z;
  strided_index.w = local_index.w * stride.w;
  strided_index.u = local_index.u * stride.u;

  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
}}

kernel void gather_kernel_4(uint linear_index              [[thread_position_in_grid]],
                            constant void * src_           [[buffer(0)]],
                            device void * dst_             [[buffer(1)]],
                            constant packed_uint4 & size   [[buffer(2)]],
                            constant packed_uint4 & stride [[buffer(3)]],
                            constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint4 local_index;
  local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
  local_index.y = linear_index / (size[3] * size[2]) % size[1];
  local_index.z = linear_index / size[3] % size[2];
  local_index.w = linear_index % size[3];

  const packed_uint4 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
}}

kernel void gather_kernel_3(uint linear_index              [[thread_position_in_grid]],
                            constant void * src_           [[buffer(0)]],
                            device void * dst_             [[buffer(1)]],
                            constant packed_uint3 & size   [[buffer(2)]],
                            constant packed_uint3 & stride [[buffer(3)]],
                            constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint3 local_index;
  local_index.x = linear_index / (size[2] * size[1]) % size[0];
  local_index.y = linear_index / size[2] % size[1];
  local_index.z = linear_index % size[2];

  const packed_uint3 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
}}

kernel void gather_kernel_2(uint linear_index              [[thread_position_in_grid]],
                            constant void * src_           [[buffer(0)]],
                            device void * dst_             [[buffer(1)]],
                            constant packed_uint2 & size   [[buffer(2)]],
                            constant packed_uint2 & stride [[buffer(3)]],
                            constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  packed_uint2 local_index;
  local_index.x = linear_index / size[1] % size[0];
  local_index.y = linear_index % size[1];

  const packed_uint2 strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
}}

kernel void gather_kernel_1(uint linear_index     [[thread_position_in_grid]],
                            constant void * src_  [[buffer(0)]],
                            device void * dst_    [[buffer(1)]],
                            constant int & size   [[buffer(2)]],
                            constant int & stride [[buffer(3)]],
                            constant uint32_t & numel [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  const int local_index = linear_index % size;
  const int strided_index = local_index * stride;
  dst[linear_index] = cast<{1}>(src[strided_index]);
}}
)METAL_GATHER";
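// Note the symmetry with the scatter kernels: the unravelling and stride
// arithmetic are identical, only the roles of src and dst are swapped, so a
// gather followed by the matching scatter round-trips a strided view through
// a contiguous buffer.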
} // namespace at::mps