TopKImpl.h

#pragma once
#include <ATen/core/TensorAccessor.h>
#include <ATen/NumericUtils.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <utility>
#include <vector>

namespace at::native {

#ifdef CPU_CAPABILITY
inline namespace CPU_CAPABILITY {
#else
inline namespace DEFAULT {
#endif

// Core topk loop, shared between CPU and QuantizedCPU
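// data[0]/data[1] hold the output values/indices for each of the n slices
// (k elements each), data[2] the corresponding input slice (dim_size
// elements); strides[i] is the byte offset between consecutive slices.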
template <typename scalar_t, typename accscalar_t>
void topk_impl_loop(
    const int64_t mode_values_stride,
    const int64_t mode_indices_stride,
    const int64_t tmp_values_stride,
    const int64_t k,
    const int64_t dim_size,
    const bool largest,
    const bool sorted,
    char** data, const int64_t* strides, const int64_t n) {

  // If k is zero, then output values and indices are empty tensors
  // So iterating over other dims is pointless
  if (k == 0) {
    return;
  }

  using elem_t = std::pair<accscalar_t, int64_t>;
  std::vector<elem_t> queue(dim_size);
  for (const auto i : c10::irange(n)) {
    TensorAccessor<scalar_t, 1> mode_values(
        reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
        &k, &mode_values_stride);
    TensorAccessor<int64_t, 1> mode_indices(
        reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
        &k, &mode_indices_stride);
    TensorAccessor<const scalar_t, 1> tmp_values(
        reinterpret_cast<scalar_t*>(data[2] + i * strides[2]),
        &dim_size, &tmp_values_stride);

    auto n_2 = dim_size;
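    // Heuristic: when k is much smaller than the slice, partially sorting
    // just the first k entries is cheaper; otherwise select the k-th
    // element with nth_element and only sort the top-k if requested.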
    auto use_partial_sort = k * 64 <= n_2;

    for (const auto j : c10::irange(n_2)) {
      queue[j].first = tmp_values[j];
      queue[j].second = j;
    }

    // we want nan to be sorted as top for numpy compatibility
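    // (the comparators below treat NaN as greater than any number, so NaN
    // entries land at the front when largest is true and at the back
    // otherwise)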
    if (use_partial_sort) {
      if (largest) {
        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
          });
      } else {
        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
          });
      }
    } else {
      if (largest) {
        std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
          });
        if (sorted) {
          std::sort(queue.begin(), queue.begin() + k - 1,
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
            });
        }
      } else {
        std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
          });
        if (sorted) {
          std::sort(queue.begin(), queue.begin() + k - 1,
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
            });
        }
      }
    }
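    // Write the selected top-k values and their original indices into the
    // outputs for this slice.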
    for (const auto j : c10::irange(k)) {
      mode_values[j] = queue[j].first;
      mode_indices[j] = queue[j].second;
    }
  }
}

} // namespace CPU_CAPABILITY
} // namespace at::native