// ThreadLocalState.h
  1. #pragma once
  2. #include <c10/core/InferenceMode.h>
  3. #include <c10/core/impl/LocalDispatchKeySet.h>
  4. #include <c10/util/Exception.h>
  5. #include <c10/util/ThreadLocalDebugInfo.h>
  6. #include <ATen/FuncTorchTLS.h>
  7. #include <ATen/PythonTorchFunctionTLS.h>
  8. #include <ATen/SavedTensorHooks.h>
  9. #include <ATen/ThreadLocalPythonObjects.h>
  10. #include <ATen/record_function.h>
  11. #include <c10/core/impl/PythonDispatcherTLS.h>
  12. #include <c10/core/impl/TorchDispatchModeTLS.h>
  13. namespace at {
  14. // Thread local state contains values that are preserved across
  15. // thread boundaries (e.g. at::launch/JIT fork, autograd).
  16. // Note at::parallel_for doesn't preserve TLS across thread boundaries.
  17. class TORCH_API ThreadLocalState {
  18. public:
  19. // Saves the thread local variables' values and
  20. // returns them as a ThreadLocalState
  21. ThreadLocalState();
  22. // set_grad_mode - force the value of the grad mode TLS in
  23. // the current state object. This is used for example in the
  24. // autograd engine.
  25. void set_grad_mode(bool enabled);
  26. // set_multithreading_enabled - force the value of the multithreadinmaximum
  27. // threads TLS in
  28. // the current state object. This is used for example in the
  29. // autograd engine.
  30. void set_multithreading_enabled(bool enabled);
  31. // Sets thread local variables in the current thread,
  32. // according to the thread boundary specified
  33. static void setThreadLocalState(const ThreadLocalState& state);
  34. private:
  35. c10::impl::LocalDispatchKeySet dispatch_key_;
  36. // ThreadLocalDebugInfo does not change after being created
  37. // with DebugInfoGuard
  38. std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
  39. // RecordFunction TLS
  40. RecordFunctionTLS rf_tls_;
  41. // TLS for out-of-tree functorch
  42. // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
  43. // pointer (spoiler alert: it's due to the indirection)
  44. // This needs to be a shared_ptr instead of a unique_ptr because
  45. // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
  46. // consider adding an explicit copy constructor for ThreadLocalState in the
  47. // future but I didn't want to add one just for this.
  48. std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
  49. // TLS for AutogradModes
  50. AutogradState autograd_tls_;
  51. // TLS for enable_torch_dispatch_mode
  52. c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
  53. // TLS for enable_python_dispatcher
  54. c10::impl::PyInterpreter* python_dispatcher_state_;
  55. // TLS for __torch_function__ (mode and disable_torch_function)
  56. at::impl::PythonTorchFunctionTLS python_torch_function_state_;
  57. // TLS for saved tensors default hooks
  58. at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
  59. bool functionalization_reapply_views_state_;
  60. // TLS for arbitrary python objects that is registered via hooks
  61. at::impl::ThreadLocalPythonObjects saved_objects_;
  62. friend class ThreadLocalStateGuard;
  63. };
  64. // Guard to set and reset the thread local state
  65. class TORCH_API ThreadLocalStateGuard {
  66. public:
  67. explicit ThreadLocalStateGuard(const ThreadLocalState& state)
  68. : prev_state_(ThreadLocalState()) {
  69. // set the given state across the thread boundary
  70. ThreadLocalState::setThreadLocalState(state);
  71. }
  72. ~ThreadLocalStateGuard() {
  73. // restore previously set variables
  74. ThreadLocalState::setThreadLocalState(prev_state_);
  75. }
  76. private:
  77. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  78. const ThreadLocalState prev_state_;
  79. };
  80. template <typename T>
  81. auto wrapPropagateTLSState(T callback) {
  82. return [tls_state = ThreadLocalState(),
  83. callback = std::move(callback)](auto&&... args) {
  84. ThreadLocalStateGuard g(tls_state);
  85. // Propagate value returned by callback().
  86. return callback(std::forward<decltype(args)>(args)...);
  87. };
  88. }
  89. } // namespace at