// PythonTorchFunctionTLS.h (1.1 KB)
#pragma once

#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>

#include <cstdint>
#include <memory>
#include <vector>
  4. namespace at::impl {
  5. enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };
  6. struct TORCH_API PythonTorchFunctionTLS {
  7. static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
  8. static TorchFunctionDisabledState get_disabled_state();
  9. static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
  10. static const std::shared_ptr<SafePyObject> pop_stack();
  11. static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
  12. static int64_t stack_len();
  13. static const PythonTorchFunctionTLS& get_state();
  14. static void set_state(const PythonTorchFunctionTLS& state);
  15. private:
  16. // The mode TLS is split into
  17. // - disabled_state, which says which part of torch function are disabled
  18. // - stack_, which is a vector of modes representing the stack of user
  19. // defined modes
  20. TorchFunctionDisabledState disabled_state_ =
  21. TorchFunctionDisabledState::ENABLED;
  22. std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
  23. };
  24. TORCH_API bool torch_function_mode_enabled();
  25. } // namespace at::impl