MPSGuardImpl.h 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. // Copyright © 2022 Apple Inc.
  2. #pragma once
  3. #include <c10/core/impl/DeviceGuardImplInterface.h>
  4. #include <c10/macros/Macros.h>
  5. #include <c10/util/Exception.h>
  6. #include <ATen/Context.h>
  7. #include <ATen/mps/MPSStream.h>
  8. #include <ATen/mps/MPSEvent.h>
  9. #ifdef __OBJC__
  10. #include <Foundation/Foundation.h>
  11. #include <Metal/Metal.h>
  12. #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
  13. #endif
  14. #include <ATen/Tensor.h>
  15. #include <c10/core/MemoryFormat.h>
  16. #include <c10/core/Storage.h>
  17. #include <c10/core/TensorImpl.h>
  18. #include <sys/_types/_size_t.h>
  19. #include <memory>
  20. #include <c10/core/UndefinedTensorImpl.h>
  21. #include <c10/util/intrusive_ptr.h>
  22. namespace at::mps {
  23. typedef MPSEvent* mpsEvent_t;
  24. // TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
  25. // https://github.com/pytorch/pytorch/issues/77170
  26. struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  27. static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
  28. // constructor
  29. MPSGuardImpl() {}
  30. explicit MPSGuardImpl(c10::DeviceType t) {
  31. TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS);
  32. }
  33. // returns the type
  34. c10::DeviceType type() const override {
  35. return c10::DeviceType::MPS;
  36. }
  37. Device exchangeDevice(Device d) const override {
  38. return Device(c10::DeviceType::MPS, 0);
  39. }
  40. Device getDevice() const override {
  41. return Device(c10::DeviceType::MPS, 0);
  42. }
  43. std::optional<Device> uncheckedGetDevice() const noexcept {
  44. return Device(c10::DeviceType::MPS, 0);
  45. }
  46. void setDevice(Device d) const override {
  47. TORCH_INTERNAL_ASSERT(d.is_mps());
  48. }
  49. void uncheckedSetDevice(Device d) const noexcept override {
  50. // TODO: Currently setting only device 0
  51. }
  52. Stream getStream(Device d) const noexcept override {
  53. return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
  54. }
  55. Stream getDefaultStream(Device d) const override {
  56. return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
  57. }
  58. // NB: These do NOT set the current device
  59. Stream exchangeStream(Stream s) const noexcept override {
  60. return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
  61. }
  62. DeviceIndex deviceCount() const noexcept override {
  63. if (at::hasMPS()) {
  64. //TODO: extend it for multi-device case
  65. return 1;
  66. } else {
  67. return 0;
  68. }
  69. }
  70. // Event-related functions
  71. void createEvent(
  72. mpsEvent_t* event,
  73. const EventFlag flag) const;
  74. void destroyEvent(
  75. void* event,
  76. const DeviceIndex device_index) const noexcept override;
  77. void record(
  78. void** event,
  79. const Stream& stream,
  80. const DeviceIndex device_index,
  81. const EventFlag flag) const override;
  82. void block(
  83. void* event,
  84. const Stream& stream) const override;
  85. bool queryEvent(void* event) const override;
  86. };
  87. /// A variant of OptionalDeviceGuard that is specialized for MPS.
  88. struct OptionalMPSGuard {
  89. explicit OptionalMPSGuard() : guard_() {}
  90. explicit OptionalMPSGuard(std::optional<Device> device_opt)
  91. : guard_(device_opt) {}
  92. /// Set the current MPS device to the passed device index, if it is not
  93. /// nullopt
  94. explicit OptionalMPSGuard(std::optional<DeviceIndex> device_index_opt)
  95. : guard_(device_index_opt) {}
  96. // Copy is not allowed
  97. OptionalMPSGuard(const OptionalMPSGuard&) = delete;
  98. OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
  99. OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
  100. OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;
  101. /// Sets the MPS device to the given device, initializing the guard if it
  102. /// is not already initialized. Errors if the given device is not a MPS
  103. /// device.
  104. void set_device(Device device) {
  105. guard_.set_device(device);
  106. }
  107. /// Sets the MPS device to the given device, initializing the guard if it is
  108. /// not already initialized. Errors if the given device is not a MPS device.
  109. void reset_device(Device device) {
  110. guard_.reset_device(device);
  111. }
  112. /// Sets the MPS device to the given device index, initializing the guard if
  113. /// it is not already initialized.
  114. void set_index(DeviceIndex device_index) {
  115. guard_.set_index(device_index);
  116. }
  117. /// Returns the device that was set immediately prior to initialization of the
  118. /// guard, or nullopt if the guard is uninitialized.
  119. std::optional<Device> original_device() const {
  120. return guard_.original_device();
  121. }
  122. /// Returns the most recent device that was set using this device guard,
  123. /// either from construction, or via set_device, if the guard is initialized,
  124. /// or nullopt if the guard is uninitialized.
  125. std::optional<Device> current_device() const {
  126. return guard_.current_device();
  127. }
  128. /// Restore the original MPS device, resetting this guard to uninitialized
  129. /// state.
  130. void reset() {
  131. guard_.reset();
  132. }
  133. private:
  134. c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
  135. };
  136. C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl);
  137. } // namespace at::mps