// LegacyVmapMode.h (927 B)
  1. #pragma once
  2. #include <c10/core/impl/LocalDispatchKeySet.h>
  3. namespace at::impl {
  4. // VmapMode contains a thread local count of how many nested vmaps
  5. // we are currently inside. That number is known as the `vmap level`.
  6. // VmapMode is used in the implementation of the Python `torch.vmap` API.
  7. //
  8. // NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
  9. struct TORCH_API VmapMode {
  10. // Returns the vmap level, aka the count of how many nested vmaps we're in.
  11. static int64_t current_vmap_level();
  12. // Increment the count of nested vmaps. If this causes the vmap level to be
  13. // greater than 0, then it enables DispatchKey::VmapMode on all tensors.
  14. static int64_t increment_nesting();
  15. // Decrements the count of nested vmaps. If this causes the vmap level to be
  16. // equal to 0, then it disables DispatchKey::VmapMode on all tensors.
  17. static int64_t decrement_nesting();
  18. };
  19. } // namespace at::impl