ReduceUtils.h 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. #pragma once
  2. #include <ATen/Parallel.h>
  3. #include <ATen/NumericUtils.h>
  4. #include <ATen/cpu/vec/vec.h>
  5. #include <ATen/cpu/vec/functional.h>
  6. #include <ATen/native/ReductionType.h>
  7. #include <c10/util/irange.h>
  8. #include <ATen/OpMathType.h>
  9. #include <ATen/native/cpu/utils.h>
  10. #include <ATen/OpMathType.h>
  11. namespace at::native {
  12. inline namespace CPU_CAPABILITY {
  13. using namespace vec;
  14. #define AT_DISPATCH_REDUCTION_TYPES(op, ...) \
  15. [&] { \
  16. switch (op) { \
  17. case ReductionType::SUM: { \
  18. static constexpr auto reduce = ReductionType::SUM; \
  19. return __VA_ARGS__(); \
  20. } \
  21. case ReductionType::MEAN: { \
  22. static constexpr auto reduce = ReductionType::MEAN; \
  23. return __VA_ARGS__(); \
  24. } \
  25. case ReductionType::MIN: { \
  26. static constexpr auto reduce = ReductionType::MIN; \
  27. return __VA_ARGS__(); \
  28. } \
  29. case ReductionType::MAX: { \
  30. static constexpr auto reduce = ReductionType::MAX; \
  31. return __VA_ARGS__(); \
  32. } \
  33. case ReductionType::PROD: { \
  34. static constexpr auto reduce = ReductionType::PROD; \
  35. return __VA_ARGS__(); \
  36. } \
  37. } \
  38. }()
  39. template <typename scalar_t, ReductionType reduce>
  40. inline vec_scalar_t<scalar_t> init_value() {
  41. using acc_t = vec_scalar_t<scalar_t>;
  42. acc_t val;
  43. if (reduce == ReductionType::SUM ||
  44. reduce == ReductionType::MEAN) {
  45. val = static_cast<acc_t>(0);
  46. } else if (reduce == ReductionType::PROD) {
  47. val = static_cast<acc_t>(1);
  48. } else if (reduce == ReductionType::MAX) {
  49. val = -std::numeric_limits<acc_t>::infinity();
  50. } else {
  51. TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
  52. val = std::numeric_limits<acc_t>::infinity();
  53. }
  54. return val;
  55. }
  56. template <typename scalar_t, ReductionType reduce>
  57. inline vec_scalar_t<scalar_t> init_value(const std::optional<Scalar>& initial) {
  58. using acc_t = vec_scalar_t<scalar_t>;
  59. if (initial.has_value()) {
  60. return initial.value().to<acc_t>();
  61. } else {
  62. return init_value<scalar_t, reduce>();
  63. }
  64. }
  65. template <typename scalar_t>
  66. inline void init(scalar_t* out, int64_t size, const vec_scalar_t<scalar_t>& val) {
  67. using Vec = Vectorized<vec_scalar_t<scalar_t>>;
  68. map<scalar_t>(
  69. [val](Vec x) { return Vec(val); },
  70. out,
  71. out,
  72. size);
  73. }
  74. template <typename scalar_t, ReductionType reduce>
  75. inline void init(scalar_t* out, int64_t size, const std::optional<Scalar>& initial) {
  76. using acc_t = vec_scalar_t<scalar_t>;
  77. acc_t val = init_value<scalar_t, reduce>(initial);
  78. init(out, size, val);
  79. }
  80. // overload with `include_self`, used by scatter_reduce
  81. template <typename scalar_t, ReductionType reduce>
  82. inline void init(scalar_t* out, int64_t size, bool include_self = false) {
  83. using acc_t = vec_scalar_t<scalar_t>;
  84. if (!include_self) {
  85. acc_t val = init_value<scalar_t, reduce>();
  86. init(out, size, val);
  87. }
  88. }
  89. template <typename scalar_t, ReductionType reduce>
  90. inline void _init(scalar_t* self_ptr, at::opmath_type<scalar_t>* buffer_ptr, int64_t size, bool include_self) {
  91. if (!include_self) {
  92. init<at::opmath_type<scalar_t>, reduce>(buffer_ptr, size, include_self);
  93. } else {
  94. vec::convert(self_ptr, buffer_ptr, size);
  95. }
  96. }
  97. template <typename scalar_t>
  98. inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
  99. _max(const scalar_t& x, const scalar_t& y) {
  100. return at::_isnan(y) ? y : std::max(x, y);
  101. }
  102. template <typename scalar_t>
  103. inline Vectorized<scalar_t> _max(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
  104. // vec::maximum propagates NaN
  105. return vec::maximum(x, y);
  106. }
  107. template <typename vec_t>
  108. inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
  109. _max(const vec_t& x, const vec_t& y) {
  110. // vec::maximum propagates NaN
  111. return maximum(x, y);
  112. }
  113. template <typename scalar_t>
  114. inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
  115. _min(const scalar_t& x, const scalar_t& y) {
  116. return at::_isnan(y) ? y : std::min(x, y);
  117. }
  118. template <typename scalar_t>
  119. inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
  120. // vec::minimum propagates NaN
  121. return vec::minimum(x, y);
  122. }
  123. template <typename vec_t>
  124. inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
  125. _min(const vec_t& x, const vec_t& y) {
  126. // vec::minimum propagates NaN
  127. return minimum(x, y);
  128. }
  129. template <typename scalar_t, typename accumut, typename Op,
  130. typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
  131. inline void map_acc(
  132. const Op& vec_fun,
  133. accumut* output_data,
  134. const accumut* input_data,
  135. const scalar_t* input_data2,
  136. int64_t size) {
  137. using Vec = vec::Vectorized<scalar_t>;
  138. using aVec = vec::Vectorized<accumut>;
  139. int64_t d = 0;
  140. constexpr int64_t kVecSize = Vec::size();
  141. constexpr int64_t kaVecSize = aVec::size();
  142. for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
  143. Vec data2_vec = Vec::loadu(input_data2 + d);
  144. auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
  145. aVec input_vec0 = aVec::loadu(input_data + d);
  146. aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
  147. vec_fun(input_vec0, data2_avec0).store(output_data + d);
  148. vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
  149. }
  150. if (size - d > 0) {
  151. int64_t tail_size = size - d;
  152. Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
  153. auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
  154. if (tail_size > kaVecSize) {
  155. aVec input_vec0 = aVec::loadu(input_data + d);
  156. aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
  157. vec_fun(input_vec0, data2_avec0).store(output_data + d);
  158. vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
  159. } else {
  160. aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
  161. vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
  162. }
  163. }
  164. }
  165. // for Max and Min, propagate NaN:
  166. template <typename T, ReductionType reduce>
  167. inline T update(const T& x, const T& y) {
  168. if (reduce == ReductionType::SUM ||
  169. reduce == ReductionType::MEAN) {
  170. return x + y;
  171. } else if (reduce == ReductionType::PROD) {
  172. return x * y;
  173. } else if (reduce == ReductionType::MAX) {
  174. return _max(x, y);
  175. } else {
  176. TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
  177. return _min(x, y);
  178. }
  179. }
  180. template <typename scalar_t, ReductionType reduce>
  181. inline void update(scalar_t* out, const scalar_t* data, int64_t K) {
  182. using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
  183. map2<scalar_t>(
  184. [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
  185. out,
  186. out,
  187. data,
  188. K);
  189. }
  190. template <typename scalar_t, ReductionType reduce,
  191. typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
  192. inline void update(at::opmath_type<scalar_t>* out, const scalar_t* data, int64_t K) {
  193. using opmath_t = at::opmath_type<scalar_t>;
  194. using Vec = vec::Vectorized<opmath_t>;
  195. map_acc<scalar_t, opmath_t>(
  196. [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
  197. out,
  198. out,
  199. data,
  200. K);
  201. }
  202. template <typename scalar_t, ReductionType reduce>
  203. inline void write(scalar_t* out, int64_t count, int64_t K) {
  204. using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
  205. if (reduce == ReductionType::MEAN) {
  206. if (count > 0) {
  207. vec::map<scalar_t>(
  208. [count](Vec x) { return x / Vec(count); },
  209. out,
  210. out,
  211. K);
  212. }
  213. }
  214. }
  215. } // namespace CPU_CAPABILITY
  216. } // namespace at::native