CallOnce.h 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. #pragma once
  2. #include <atomic>
  3. #include <mutex>
  4. #include <utility>
  5. #include <c10/macros/Macros.h>
  6. #include <c10/util/C++17.h>
  7. namespace c10 {
  // c10's own call_once, used instead of std::call_once to avoid the deadlock
  // in std::call_once. It is a simplified version of folly's implementation
  // and likely has a much higher memory footprint than that original.
  //
  // Fast path: if `flag.test_once()` reports that the work is already done,
  // nothing happens. Otherwise `f` and `args` are perfectly forwarded to
  // `flag.call_once_slow(...)`, which decides whether to run them.
  //
  // `Flag` must provide `test_once()` and `call_once_slow(F&&, Args&&...)`;
  // c10::once_flag exposes both to this function through friendship.
  template <typename Flag, typename F, typename... Args>
  inline void call_once(Flag& flag, F&& f, Args&&... args) {
    if (!C10_LIKELY(flag.test_once())) {
      flag.call_once_slow(std::forward<F>(f), std::forward<Args>(args)...);
    }
  }
  18. class once_flag {
  19. public:
  20. #ifndef _WIN32
  21. // running into build error on MSVC. Can't seem to get a repro locally so I'm
  22. // just avoiding constexpr
  23. //
  24. // C:/actions-runner/_work/pytorch/pytorch\c10/util/CallOnce.h(26): error:
  25. // defaulted default constructor cannot be constexpr because the
  26. // corresponding implicitly declared default constructor would not be
  27. // constexpr 1 error detected in the compilation of
  28. // "C:/actions-runner/_work/pytorch/pytorch/aten/src/ATen/cuda/cub.cu".
  29. constexpr
  30. #endif
  31. once_flag() noexcept = default;
  32. once_flag(const once_flag&) = delete;
  33. once_flag& operator=(const once_flag&) = delete;
  34. private:
  35. template <typename Flag, typename F, typename... Args>
  36. friend void call_once(Flag& flag, F&& f, Args&&... args);
  37. template <typename F, typename... Args>
  38. void call_once_slow(F&& f, Args&&... args) {
  39. std::lock_guard<std::mutex> guard(mutex_);
  40. if (init_.load(std::memory_order_relaxed)) {
  41. return;
  42. }
  43. c10::guts::invoke(std::forward<F>(f), std::forward<Args>(args)...);
  44. init_.store(true, std::memory_order_release);
  45. }
  46. bool test_once() {
  47. return init_.load(std::memory_order_acquire);
  48. }
  49. void reset_once() {
  50. init_.store(false, std::memory_order_release);
  51. }
  52. private:
  53. std::mutex mutex_;
  54. std::atomic<bool> init_{false};
  55. };
  56. } // namespace c10