FuncTorchTLS.h 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. #pragma once
#include <c10/macros/Macros.h>

#include <cstdint>
#include <memory>
  4. namespace at::functorch {
  5. // NOTE [functorch TLS in pytorch/pytorch]
  6. //
  7. // functorch lives out-of-tree. However, it has some TLS that needs to be
  8. // propagated. The solution for that is we store a pointer to the TLS
  9. // inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to
  10. // include whatever functorch needs.
  11. //
  // We need to store a pointer due to the indirection:
  // inside functorch, we will create a subclass of FuncTorchTLSBase called
  // FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack.
  // FuncTorchTLSBase itself doesn't have any metadata, because the types that
  // make up that metadata live in functorch and haven't been defined at this
  // point.
  17. //
  // Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside
  // functorch, we will assign a FuncTorchTLSImpl* to the FuncTorchTLSBase*.
  // We can't directly pass around FuncTorchTLSBase by value (without a pointer)
  // because FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of
  // having more elements: copying it into a FuncTorchTLSBase would slice off
  // those extra members.
  23. struct TORCH_API FuncTorchTLSBase {
  24. virtual ~FuncTorchTLSBase() = default;
  25. virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
  26. virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
  27. virtual void checkSupportsCppAutogradFunction() const = 0;
  28. virtual void checkSupportsInplaceRequiresGrad() const = 0;
  29. virtual void checkSupportsRetainGrad() const = 0;
  30. };
  31. // returns deepcopy of the functorch tls
  32. TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS();
  33. // sets the functorch tls. always does a deep copy.
  34. TORCH_API void setFuncTorchTLS(
  35. const std::shared_ptr<const FuncTorchTLSBase>& state);
  36. // get a mutable reference to the functorch tls
  37. TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
  38. } // namespace at::functorch