Parallel-inl.h 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #pragma once
  2. #include <c10/util/Exception.h>
  3. #include <c10/util/ParallelGuard.h>
  4. #include <c10/util/SmallVector.h>
  5. namespace at {
  6. template <class F>
  7. inline void parallel_for(
  8. const int64_t begin,
  9. const int64_t end,
  10. const int64_t grain_size,
  11. const F& f) {
  12. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
  13. if (begin >= end) {
  14. return;
  15. }
  16. #ifdef INTRA_OP_PARALLEL
  17. at::internal::lazy_init_num_threads();
  18. const auto numiter = end - begin;
  19. const bool use_parallel =
  20. (numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
  21. at::get_num_threads() > 1);
  22. if (!use_parallel) {
  23. internal::ThreadIdGuard tid_guard(0);
  24. c10::ParallelGuard guard(true);
  25. f(begin, end);
  26. return;
  27. }
  28. internal::invoke_parallel(
  29. begin, end, grain_size, [&](int64_t begin, int64_t end) {
  30. c10::ParallelGuard guard(true);
  31. f(begin, end);
  32. });
  33. #else
  34. internal::ThreadIdGuard tid_guard(0);
  35. c10::ParallelGuard guard(true);
  36. f(begin, end);
  37. #endif
  38. }
  39. template <class scalar_t, class F, class SF>
  40. inline scalar_t parallel_reduce(
  41. const int64_t begin,
  42. const int64_t end,
  43. const int64_t grain_size,
  44. const scalar_t ident,
  45. const F& f,
  46. const SF& sf) {
  47. TORCH_CHECK(grain_size >= 0);
  48. if (begin >= end) {
  49. return ident;
  50. }
  51. #ifdef INTRA_OP_PARALLEL
  52. at::internal::lazy_init_num_threads();
  53. const auto max_threads = at::get_num_threads();
  54. const bool use_parallel =
  55. ((end - begin) > grain_size && !at::in_parallel_region() &&
  56. max_threads > 1);
  57. if (!use_parallel) {
  58. internal::ThreadIdGuard tid_guard(0);
  59. c10::ParallelGuard guard(true);
  60. return f(begin, end, ident);
  61. }
  62. c10::SmallVector<scalar_t, 64> results(max_threads, ident);
  63. internal::invoke_parallel(
  64. begin,
  65. end,
  66. grain_size,
  67. [&](const int64_t my_begin, const int64_t my_end) {
  68. const auto tid = at::get_thread_num();
  69. c10::ParallelGuard guard(true);
  70. results[tid] = f(my_begin, my_end, ident);
  71. });
  72. scalar_t result = ident;
  73. for (auto partial_result : results) {
  74. result = sf(result, partial_result);
  75. }
  76. return result;
  77. #else
  78. internal::ThreadIdGuard tid_guard(0);
  79. c10::ParallelGuard guard(true);
  80. return f(begin, end, ident);
  81. #endif
  82. }
  83. } // namespace at