utils.h

#pragma once

#include <ATen/Parallel.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <c10/util/llvmMathExtras.h>

#include <tuple>
#include <vector>

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#endif

namespace at {
namespace native {

// Store a full float vector to `dst`. For reduced-precision destinations
// (BFloat16, Half), convert from float on the fly and store only the
// Vectorized<float>::size() converted lanes.
template <typename T>
inline void _store(T* dst, at::vec::Vectorized<T> src) {
  src.store(dst);
}

inline void _store(at::BFloat16* dst, at::vec::Vectorized<float> src) {
  auto res = at::vec::convert_float_bfloat16(src, src);
  res.store(dst, at::vec::Vectorized<float>::size());
}

inline void _store(at::Half* dst, at::vec::Vectorized<float> src) {
  auto res = at::vec::convert_float_half(src, src);
  res.store(dst, at::vec::Vectorized<float>::size());
}
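
// A minimal usage sketch (illustrative only; `src`, `dst`, and `n` are
// hypothetical, with `n` a multiple of Vectorized<float>::size()): compute in
// float and let _store handle the conversion to the destination dtype.
//
//   using fVec = at::vec::Vectorized<float>;
//   for (int64_t i = 0; i < n; i += fVec::size()) {
//     fVec v = fVec::loadu(src + i) * fVec(2.0f);
//     _store(dst + i, v);  // plain store for float dst, convert-store for bf16/half
//   }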

inline namespace CPU_CAPABILITY {

// Decompose a flat offset into per-dimension indices, listed from the
// outermost (slowest-varying) dimension to the innermost. Each (x, X) pair
// receives the index for a dimension of size X; the return value is the
// offset that remains after all listed dimensions are divided out.
template <typename T>
inline T data_index_init(T offset) {
  return offset;
}

template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  offset = data_index_init(offset, std::forward<Args>(args)...);
  x = offset % X;
  return offset / X;
}

// Advance the multi-dimensional index by one step, odometer-style: the
// innermost (last listed) dimension moves fastest. Returns true when the
// whole index wraps back to zero.
inline bool data_index_step() {
  return true;
}

template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  if (data_index_step(std::forward<Args>(args)...)) {
    x = ((x + 1) == X) ? 0 : (x + 1);
    return x == 0;
  }
  return false;
}
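
// A minimal usage sketch (illustrative only; `B`, `C`, and the loop body are
// hypothetical), in the style of an at::parallel_for kernel: decompose the
// flat [begin, end) range into a (b, c) index once, then keep it in sync per
// iteration.
//
//   at::parallel_for(0, B * C, 0, [&](int64_t begin, int64_t end) {
//     int64_t b{0}, c{0};
//     data_index_init(begin, b, B, c, C);
//     for (const auto i : c10::irange(begin, end)) {
//       // ... work on element (b, c), i.e. flat index i ...
//       data_index_step(b, B, c, C);
//     }
//   });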

// Helper struct for bfloat16 vectorization.
// Useful when you need float as the intermediate or accumulation dtype:
// one Vec2 holds a full Vectorized<BFloat16> worth of lanes as two
// Vectorized<float> halves.
using namespace vec;
struct Vec2 {
  Vectorized<float> val0, val1;
  Vec2(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {}
  Vec2(float v) : val0(v), val1(v) {}
  static Vec2 loadu(const BFloat16* ptr) {
    auto [v0, v1] = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
    return {v0, v1};
  }
  static Vec2 loadu(const float* ptr) {
    return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
  }
  void store(BFloat16* ptr) const {
    Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
    val.store(ptr);
  }
  void store(float* ptr) const {
    val0.store(ptr);
    val1.store(ptr + Vectorized<float>::size());
  }
};
inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; }
inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; }
inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; }
inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; }
inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; }
inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; }

// Map a scalar type to its vectorized counterpart; BFloat16 maps to Vec2 so
// that arithmetic happens in float.
template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
template <> struct VectorizedType<BFloat16> { using type = Vec2; };
template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;
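
// A minimal usage sketch (illustrative only; `a`, `b`, and `out` are
// hypothetical BFloat16 buffers of at least Vectorized<BFloat16>::size()
// elements): multiply-accumulate with float as the accumulation dtype.
//
//   Vec2 acc(0.0f);
//   acc = acc + Vec2::loadu(a) * Vec2::loadu(b);
//   acc.store(out);  // converts the two float halves back to bfloat16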

// Helpers to load parameters of mixed data type (float, BFloat16, or Half)
// uniformly as two Vectorized<float> halves.
inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr) {
  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr) {
  return convert_half_float(Vectorized<Half>::loadu(ptr));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr) {
  using Vec = Vectorized<float>;
  return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size()));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr, int64_t count) {
  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr, count));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr, int64_t count) {
  return convert_half_float(Vectorized<Half>::loadu(ptr, count));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr, int64_t count) {
  using Vec = Vectorized<float>;
  if (count > Vec::size()) {
    return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size(), count - Vec::size()));
  } else {
    return std::make_tuple(Vec::loadu(ptr, count), Vec(0));
  }
}
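
// A minimal usage sketch (illustrative only; `weight` is a hypothetical
// pointer whose pointee type may be float, BFloat16, or Half): one kernel
// body serves all three parameter dtypes.
//
//   auto [w0, w1] = load2f(weight + offset);
//   // w0, w1 are Vectorized<float> regardless of the source dtype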

} // namespace CPU_CAPABILITY

namespace utils {

template <typename T>
T CeilLog2(const T& x) {
  if (x <= 2) {
    return 1;
  }
  // The index of the last set bit is floor(log2(x)), and floor + 1 is ceil,
  // except when x is an exact power of 2, so subtract 1 first.
  return static_cast<T>(llvm::findLastSet(static_cast<uint64_t>(x) - 1)) + 1;
}
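
// For example: CeilLog2(5) == 3 (findLastSet(4) == 2, plus 1) and
// CeilLog2(8) == 3 (findLastSet(7) == 2, plus 1), matching ceil(log2(x)).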

// Matrix transpose:
//   src has shape M x N with leading dimension ld_src,
//   dst has shape N x M with leading dimension ld_dst.
template <typename T>
inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  for (int64_t j = 0; j < N; j++) {
    for (int64_t i = 0; i < M; i++) {
      dst[j * ld_dst + i] = src[i * ld_src + j];
    }
  }
}

#ifdef USE_FBGEMM
// Use FBGEMM's SIMD transpose for float when available.
template <>
inline void transpose<float>(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  fbgemm::transpose_simd<float>(M, N, src, ld_src, dst, ld_dst);
}
#endif
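
// A minimal usage sketch (values are illustrative): transpose a contiguous
// 2x3 row-major matrix into a 3x2 one.
//
//   float src[6] = {1, 2, 3, 4, 5, 6};  // 2x3
//   float dst[6];
//   transpose(2, 3, src, /*ld_src=*/3, dst, /*ld_dst=*/2);
//   // dst is now {1, 4, 2, 5, 3, 6}  (3x2)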

// Parallelize over the rows of a CSR matrix. Parallelizing directly over the
// `M` rows may lead to load imbalance when nnz is unevenly distributed, so
// statically compute a thread partition here that gives each thread roughly
// the same nnz payload.
template <typename index_t, typename F>
inline void parallel_sparse_csr(
    const TensorAccessor<index_t, 1>& crow_acc,
    const int64_t M,
    const int64_t nnz,
    const F& f) {
  TORCH_CHECK(crow_acc.size(0) == M + 1);
  int num_threads = at::get_num_threads();
  std::vector<int64_t> thread_splits(num_threads + 1, M);
  int64_t thread_average_payload = std::max((int64_t)1, divup(nnz, num_threads));
  thread_splits[0] = 0;
  int64_t sum = 0;
  int64_t t = 1;
  for (const auto m : c10::irange(M)) {
    int64_t row_start = crow_acc[m];
    int64_t row_end = crow_acc[m + 1];
    sum += row_end - row_start;
    if (sum > t * thread_average_payload) {
      thread_splits[t] = m;
      t++;
    }
  }
  // Set the last split explicitly: rounding in `thread_average_payload` may
  // leave trailing splits untouched by the loop above.
  thread_splits[num_threads] = M;
  at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
    int tid = at::get_thread_num();
    int64_t begin = thread_splits[tid];
    int64_t end = thread_splits[tid + 1];
    f(begin, end);
  });
}
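
// A minimal usage sketch (illustrative only; `crow_acc`, `M`, `nnz`, and the
// per-row work are hypothetical): each thread receives a [begin, end) row
// range carrying roughly the same number of nonzeros.
//
//   parallel_sparse_csr(crow_acc, M, nnz, [&](int64_t begin, int64_t end) {
//     for (const auto m : c10::irange(begin, end)) {
//       // ... process entries crow_acc[m] .. crow_acc[m + 1] of row m ...
//     }
//   });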

} // namespace utils
} // namespace native
} // namespace at