KernelUtils.cuh 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. #pragma once
  2. #include <ATen/cuda/Atomic.cuh>
  3. #if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
  4. #include <cuda_bf16.h>
  5. #endif
  6. namespace at {
  7. namespace native {
  8. __device__ __forceinline__ size_t
  9. idx(const size_t nc,
  10. const size_t height,
  11. const size_t width,
  12. const size_t h,
  13. const size_t w) {
  14. return (nc * height + h) * width + w;
  15. }
  16. // for channels-last
  17. __device__ __forceinline__ size_t
  18. idx_cl(
  19. const size_t n, const size_t h, const size_t w, const size_t c,
  20. const size_t height, const size_t width, const size_t channel
  21. ) {
  22. return ((n * height + h) * width + w) * channel + c;
  23. }
  24. // fastSpecializedAtomicAdd (and fastAtomicAdd) are an optimization
  25. // that speed up half-precision atomics. The situation with half
  26. // precision atomics is that we have a slow __half atomic, and
  27. // a fast vectored __half2 atomic (this can be worth up to a 6x
  28. // speedup, see https://github.com/pytorch/pytorch/pull/21879).
  29. // We can convert a __half atomic into a __half2 atomic by simply
  30. // pairing the __half with a zero entry on the left/right depending
  31. // on alignment... but only if this wouldn't cause an out of bounds
  32. // access! Thus, you must specify tensor and numel so we can check
  33. // if you would be out-of-bounds and use a plain __half atomic if
  34. // you would be.
  35. template <
  36. typename scalar_t,
  37. typename index_t,
  38. typename std::enable_if<std::is_same<c10::Half, scalar_t>::value>::type* =
  39. nullptr>
  40. __device__ __forceinline__ void fastSpecializedAtomicAdd(
  41. scalar_t* tensor,
  42. index_t index,
  43. const index_t numel,
  44. scalar_t value) {
  45. #if ( \
  46. (defined(USE_ROCM)) || \
  47. (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  48. gpuAtomicAddNoReturn(
  49. reinterpret_cast<at::Half*>(tensor) + index,
  50. static_cast<at::Half>(value));
  51. #else
  52. // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
  53. __half* target_addr = reinterpret_cast<__half*>(tensor + index);
  54. bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);
  55. if (low_byte && index < (numel - 1)) {
  56. __half2 value2;
  57. value2.x = static_cast<__half>(value);
  58. value2.y = __int2half_rz(0);
  59. atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);
  60. } else if (!low_byte && index > 0) {
  61. __half2 value2;
  62. value2.x = __int2half_rz(0);
  63. value2.y = static_cast<__half>(value);
  64. atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);
  65. } else {
  66. atomicAdd(
  67. reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value));
  68. }
  69. #endif
  70. }
  71. template <
  72. typename scalar_t,
  73. typename index_t,
  74. typename std::enable_if<std::is_same<c10::BFloat16, scalar_t>::value>::type* =
  75. nullptr>
  76. __device__ __forceinline__ void fastSpecializedAtomicAdd(
  77. scalar_t* tensor,
  78. index_t index,
  79. const index_t numel,
  80. scalar_t value) {
  81. #if ( \
  82. (defined(USE_ROCM)) || \
  83. (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
  84. gpuAtomicAddNoReturn(
  85. reinterpret_cast<at::BFloat16*>(tensor) + index,
  86. static_cast<at::BFloat16>(value));
  87. #else
  88. // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
  89. __nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index);
  90. bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__nv_bfloat162) == 0);
  91. if (low_byte && index < (numel - 1)) {
  92. __nv_bfloat162 value2;
  93. value2.x = *reinterpret_cast<__nv_bfloat16*>(&value);
  94. value2.y = __int2bfloat16_rz(0);
  95. atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);
  96. } else if (!low_byte && index > 0) {
  97. __nv_bfloat162 value2;
  98. value2.x = __int2bfloat16_rz(0);
  99. value2.y = *reinterpret_cast<__nv_bfloat16*>(&value);
  100. atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);
  101. } else {
  102. atomicAdd(
  103. reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value));
  104. }
  105. #endif
  106. }
  107. template <
  108. typename scalar_t,
  109. typename index_t,
  110. typename std::enable_if<!std::is_same<c10::Half, scalar_t>::value && !std::is_same<c10::BFloat16, scalar_t>::value >::type* =
  111. nullptr>
  112. __device__ __forceinline__ void fastSpecializedAtomicAdd(
  113. scalar_t* tensor,
  114. index_t index,
  115. const index_t numel,
  116. scalar_t value) {
  117. gpuAtomicAddNoReturn(tensor + index, value);
  118. }
  119. template <class scalar_t, class index_t>
  120. __device__ __forceinline__ void fastAtomicAdd(
  121. scalar_t* tensor,
  122. index_t index,
  123. const index_t numel,
  124. scalar_t value,
  125. bool fast_atomics) {
  126. if (fast_atomics) {
  127. fastSpecializedAtomicAdd(tensor, index, numel, value);
  128. } else {
  129. gpuAtomicAddNoReturn(tensor + index, value);
  130. }
  131. }
  132. } // namespace native
  133. } // namespace at