XPUGuardImpl.h 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. #pragma once
  2. #include <c10/core/DeviceGuard.h>
  3. #include <c10/core/impl/DeviceGuardImplInterface.h>
  4. #include <c10/core/impl/GPUTrace.h>
  5. #include <c10/xpu/XPUCachingAllocator.h>
  6. #include <c10/xpu/XPUFunctions.h>
  7. #include <c10/xpu/XPUStream.h>
  8. #include <vector>
  9. namespace c10::xpu::impl {
  10. struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  11. static constexpr DeviceType static_type = kXPU;
  12. XPUGuardImpl() = default;
  13. explicit XPUGuardImpl(DeviceType t) {
  14. TORCH_INTERNAL_ASSERT(t == kXPU);
  15. }
  16. DeviceType type() const override {
  17. return kXPU;
  18. }
  19. Device exchangeDevice(Device d) const override {
  20. TORCH_INTERNAL_ASSERT(d.is_xpu());
  21. const auto old_device_index = c10::xpu::exchange_device(d.index());
  22. return Device(kXPU, old_device_index);
  23. }
  24. Device getDevice() const override {
  25. const auto device = c10::xpu::current_device();
  26. return Device(kXPU, device);
  27. }
  28. void setDevice(Device d) const override {
  29. TORCH_INTERNAL_ASSERT(d.is_xpu());
  30. c10::xpu::set_device(d.index());
  31. }
  32. void uncheckedSetDevice(Device d) const noexcept override {
  33. c10::xpu::set_device(d.index());
  34. }
  35. Stream getStream(Device d) const noexcept override {
  36. return getCurrentXPUStream(d.index()).unwrap();
  37. }
  38. Stream getNewStream(Device d, int priority = 0) const override {
  39. return getStreamFromPool(priority, d.index());
  40. }
  41. Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
  42. const override {
  43. return getStreamFromPool(isHighPriority, d.index());
  44. }
  45. // NB: These do NOT set the current device
  46. Stream exchangeStream(Stream s) const noexcept override {
  47. const XPUStream stream(s);
  48. const auto old_stream = getCurrentXPUStream(s.device().index());
  49. setCurrentXPUStream(stream);
  50. return old_stream.unwrap();
  51. }
  52. DeviceIndex deviceCount() const noexcept override {
  53. return c10::xpu::device_count();
  54. }
  55. // Event-related functions
  56. void destroyEvent(void* event, const DeviceIndex device_index)
  57. const noexcept override {
  58. if (!event)
  59. return;
  60. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  61. if (C10_UNLIKELY(interp)) {
  62. (*interp)->trace_gpu_event_deletion(
  63. c10::kXPU, reinterpret_cast<uintptr_t>(event));
  64. }
  65. delete reinterpret_cast<sycl::event*>(event);
  66. }
  67. void record(
  68. void** event,
  69. const Stream& stream,
  70. const DeviceIndex device_index,
  71. const EventFlag flag) const override {
  72. TORCH_CHECK(
  73. device_index == -1 || device_index == stream.device_index(),
  74. "Event device index ",
  75. device_index,
  76. " does not match recording stream's device index ",
  77. stream.device_index(),
  78. ".");
  79. auto* xpu_event = reinterpret_cast<sycl::event*>(*event);
  80. const XPUStream xpu_stream{stream};
  81. // Delete the event previously recorded.
  82. if (xpu_event)
  83. delete xpu_event;
  84. xpu_event = new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier());
  85. *event = reinterpret_cast<void*>(xpu_event);
  86. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  87. if (C10_UNLIKELY(interp)) {
  88. (*interp)->trace_gpu_event_record(
  89. c10::kXPU,
  90. reinterpret_cast<uintptr_t>(xpu_event),
  91. reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
  92. }
  93. }
  94. void block(void* event, const Stream& stream) const override {
  95. if (!event)
  96. return;
  97. auto* xpu_event = reinterpret_cast<sycl::event*>(event);
  98. std::vector<sycl::event> event_list{*xpu_event};
  99. const XPUStream xpu_stream(stream);
  100. xpu_stream.queue().ext_oneapi_submit_barrier(event_list);
  101. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  102. if (C10_UNLIKELY(interp)) {
  103. (*interp)->trace_gpu_event_wait(
  104. c10::kXPU,
  105. reinterpret_cast<uintptr_t>(xpu_event),
  106. reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
  107. }
  108. }
  109. bool queryEvent(void* event) const override {
  110. using namespace sycl::info;
  111. if (!event)
  112. return true;
  113. auto* xpu_event = reinterpret_cast<sycl::event*>(event);
  114. return xpu_event->get_info<event::command_execution_status>() ==
  115. event_command_status::complete;
  116. }
  117. // Stream-related functions
  118. bool queryStream(const Stream& stream) const override {
  119. const XPUStream xpu_stream{stream};
  120. return xpu_stream.query();
  121. }
  122. void synchronizeStream(const Stream& stream) const override {
  123. const XPUStream xpu_stream{stream};
  124. xpu_stream.synchronize();
  125. }
  126. void synchronizeEvent(void* event) const override {
  127. if (!event)
  128. return;
  129. auto* xpu_event = reinterpret_cast<sycl::event*>(event);
  130. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  131. if (C10_UNLIKELY(interp)) {
  132. (*interp)->trace_gpu_event_synchronization(
  133. c10::kXPU, reinterpret_cast<uintptr_t>(xpu_event));
  134. }
  135. xpu_event->wait_and_throw();
  136. }
  137. void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
  138. const override {
  139. const XPUStream xpu_stream{stream};
  140. XPUCachingAllocator::recordStream(data_ptr, xpu_stream);
  141. }
  142. double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
  143. const override {
  144. TORCH_CHECK_NOT_IMPLEMENTED(
  145. false, "elapsedTime is not supported by XPU backend.");
  146. }
  147. };
  148. } // namespace c10::xpu::impl