DeviceAccelerator.h 932 B

123456789101112131415161718192021222324252627
  1. #pragma once
  2. #include <c10/core/DeviceType.h>
  3. #include <c10/macros/Macros.h>
  4. #include <ATen/detail/MTIAHooksInterface.h>
  5. #include <optional>
  6. // This file defines the top level Accelerator concept for PyTorch.
  7. // A device is an accelerator per the definition here if:
  8. // - It is mutually exclusive with all other accelerators
  9. // - It performs asynchronous compute via a Stream/Event system
  10. // - It provides a set of common APIs as defined by AcceleratorHooksInterface
  11. //
  12. // As of today, accelerator devices are (in no particular order):
  13. // CUDA, MTIA, PrivateUse1
  14. // We plan to add the following once all the proper APIs are supported and tested:
  15. // HIP, MPS, XPU
  16. namespace at {
  17. // Ensures that only one accelerator is available (at
  18. // compile time if possible) and return it.
  19. // When checked is true, the returned optional always has a value.
  20. TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false);
  21. } // namespace at