HIPGuardImplMasqueradingAsCUDA.h

#pragma once

#include <ATen/hip/HIPConfig.h>

// The includes of HIPGuard.h
#include <c10/hip/impl/HIPGuardImpl.h>
#include <c10/hip/HIPMacros.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/util/Exception.h>

#include <c10/hip/impl/HIPGuardImpl.h>

#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
// Use of the c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10 { namespace hip {

// Note [Masquerading as CUDA]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// c10_hip is very easy to understand: it is HIPified from c10_cuda,
// and anywhere you said CUDA, the source code now says HIP. HIPified
// PyTorch is much harder to understand: it is HIPified from regular
// PyTorch, yes, but NO source-to-source translation from CUDA to
// HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP".
// For example, when you use HIPified PyTorch, you say x.cuda() to
// move a tensor onto a ROCm device. We call this situation "HIP
// masquerading as CUDA".
//
// This leads to a very awkward situation when we want to call c10_hip
// code from PyTorch, since c10_hip expects things to be called HIP,
// but PyTorch calls them CUDA (even though they are really HIP
// masquerading as CUDA). To fix this impedance mismatch, we have
// MasqueradingAsCUDA variants of all c10_hip classes. These translate
// between the "HIP" and "HIP masquerading as CUDA" worlds. For example,
// HIPGuardImplMasqueradingAsCUDA (this file) provides something like a
// HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type()
// returns CUDA, and getDevice() reports the current HIP device as a
// CUDA device).
//
// We should be able to delete all of these classes entirely once
// we switch PyTorch to calling a HIP a HIP.
//
// When you add a new MasqueradingAsCUDA class/function, you also need
// to update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py
// so that the hipifier rewrites the corresponding CUDA names to them.
//
// By the way, note that the cpp file associated with this header also
// *overwrites* the entry in the DeviceGuardImpl registry for CUDA with
// this HIP implementation.
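//
// Roughly, the accompanying .cpp performs that override by registering
// this implementation under DeviceType::CUDA. The snippet below is only
// a sketch of that registration (using the C10_REGISTER_GUARD_IMPL macro
// from c10/core/impl/DeviceGuardImplInterface.h), not a verbatim copy of
// the .cpp:
//
//   // In HIPGuardImplMasqueradingAsCUDA.cpp (sketch):
//   C10_REGISTER_GUARD_IMPL(CUDA, HIPGuardImplMasqueradingAsCUDA);
//
// Once that runs, generic code asking the DeviceGuardImpl registry for
// the CUDA implementation gets this HIP-backed one.
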
struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface {
  static constexpr c10::DeviceType static_type = c10::DeviceType::CUDA;

  HIPGuardImplMasqueradingAsCUDA() {}
  HIPGuardImplMasqueradingAsCUDA(c10::DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == c10::DeviceType::CUDA);
  }
  c10::DeviceType type() const override {
    return c10::DeviceType::CUDA;
  }
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    Device old_device = getDevice();
    if (old_device.index() != d.index()) {
      C10_HIP_CHECK(hipSetDevice(d.index()));
    }
    return old_device;
  }
  Device getDevice() const override {
    int device;
    C10_HIP_CHECK(hipGetDevice(&device));
    return Device(c10::DeviceType::CUDA, device);
  }
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    C10_HIP_CHECK(hipSetDevice(d.index()));
  }
  void uncheckedSetDevice(Device d) const noexcept override {
    C10_HIP_CHECK_WARN(hipSetDevice(d.index()));
  }
  Stream getStream(Device d) const noexcept override {
    return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultHIPStreamMasqueradingAsCUDA(d.index());
  }
  Stream getNewStream(Device d, int priority = 0) const override {
    return getStreamFromPoolMasqueradingAsCUDA(priority, d.index());
  }
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override {
    return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index());
  }
  Stream exchangeStream(Stream s) const noexcept override {
    HIPStreamMasqueradingAsCUDA cs(s);
    auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index());
    setCurrentHIPStreamMasqueradingAsCUDA(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    // Initialize to 0 so a sane count is returned even if
    // hipGetDeviceCount reports that no devices are present.
    int deviceCnt = 0;
    hipError_t _err;
    _err = hipGetDeviceCount(&deviceCnt);
    if (_err != hipErrorNoDevice && _err != hipSuccess)
      C10_HIP_CHECK(_err);
    return deviceCnt;
  }
  // Event-related functions
  // Note: hipEventCreateWithFlags should be called on the same device as
  // the recording stream's device.
  void createEvent(
    hipEvent_t* hip_event,
    const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to HIP flag
    auto hip_flag = hipEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
        hip_flag = hipEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
        hip_flag = hipEventDefault;
        break;
      default:
        TORCH_CHECK(false, "HIP event received unknown flag");
    }

    C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag));
  }
  void destroyEvent(
    void* event,
    const DeviceIndex device_index) const noexcept override {
    if (!event) return;
    auto hip_event = static_cast<hipEvent_t>(event);
    int orig_device;
    C10_HIP_CHECK_WARN(hipGetDevice(&orig_device));
    C10_HIP_CHECK_WARN(hipSetDevice(device_index));
    C10_HIP_CHECK_WARN(hipEventDestroy(hip_event));
    C10_HIP_CHECK_WARN(hipSetDevice(orig_device));
  }
  void record(void** event,
    const Stream& stream,
    const DeviceIndex device_index,
    const EventFlag flag) const override {
    TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
      "Event device index ",
      device_index,
      " does not match recording stream's device index ",
      stream.device_index(),
      ".");

    hipEvent_t hip_event = static_cast<hipEvent_t>(*event);
    HIPStreamMasqueradingAsCUDA hip_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!hip_event) createEvent(&hip_event, flag);
    C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream));
    // Makes the void* point to the (possibly just allocated) HIP event
    *event = hip_event;

    // Resets device
    setDevice(orig_device);
  }
  void block(
    void* event,
    const Stream& stream) const override {
    if (!event) return;
    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
    HIPStreamMasqueradingAsCUDA hip_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_HIP_CHECK(hipStreamWaitEvent(
      hip_stream,
      hip_event,
      /*flags (must be zero)=*/ 0));
    setDevice(orig_device);
  }
  bool queryEvent(void* event) const override {
    if (!event) return true;
    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
    const hipError_t err = hipEventQuery(hip_event);
    if (err != hipErrorNotReady) {
      C10_HIP_CHECK(err);
    } else {
      // ignore and clear the error if not ready
      (void)hipGetLastError();
    }
    return (err == hipSuccess);
  }
  // Stream-related functions
  bool queryStream(const Stream& stream) const override {
    HIPStreamMasqueradingAsCUDA hip_stream{stream};
    return hip_stream.query();
  }

  void synchronizeStream(const Stream& stream) const override {
    HIPStreamMasqueradingAsCUDA hip_stream{stream};
    hip_stream.synchronize();
  }

  void synchronizeEvent(void* event) const override {
    if (!event)
      return;
    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
    C10_HIP_CHECK(hipEventSynchronize(hip_event));
  }

  void recordDataPtrOnStream(
    const c10::DataPtr& data_ptr,
    const Stream& stream) const override {
    HIPStreamMasqueradingAsCUDA hip_stream{stream};
    HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA(data_ptr, hip_stream);
  }
  double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
      const override {
    TORCH_CHECK(
        event1 && event2,
        "Both events must be recorded before calculating elapsed time.");
    int orig_device;
    C10_HIP_CHECK(hipGetDevice(&orig_device));
    C10_HIP_CHECK(hipSetDevice(device_index));
    hipEvent_t hip_event1 = static_cast<hipEvent_t>(event1);
    hipEvent_t hip_event2 = static_cast<hipEvent_t>(event2);
    float time_ms = 0;
    // raise hipErrorNotReady if either event is recorded but not yet completed
    C10_HIP_CHECK(hipEventElapsedTime(&time_ms, hip_event1, hip_event2));
    C10_HIP_CHECK(hipSetDevice(orig_device));
    return static_cast<double>(time_ms);
  }
};
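
// Because the registry entry for CUDA is replaced (see the note above),
// device-generic code ends up dispatching into the methods above without
// knowing about HIP. A minimal sketch of that behavior; the lookup helper
// c10::impl::getDeviceGuardImpl is assumed here from
// c10/core/impl/DeviceGuardImplInterface.h:
//
//   // Generic code asks for the "CUDA" guard implementation...
//   const c10::impl::DeviceGuardImplInterface* impl =
//       c10::impl::getDeviceGuardImpl(c10::DeviceType::CUDA);
//   // ...and on a ROCm build gets HIPGuardImplMasqueradingAsCUDA, so this
//   // reports the current HIP device, typed as CUDA.
//   c10::Device current = impl->getDevice();
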
// All of the guards which have HIPGuardImpl burned in need to also have
// variants using HIPGuardImplMasqueradingAsCUDA.

/// This code is all a direct copy from c10/cuda/CUDAGuard.h, but with
/// the correct InlineDeviceGuard burned in. Sorry about the
/// copy-pasting. A short usage sketch follows each of the main guard
/// types below.
struct HIPGuardMasqueradingAsCUDA {
  explicit HIPGuardMasqueradingAsCUDA() = delete;
  explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {}
  explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {}

  HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete;
  HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete;
  HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete;
  HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete;

  void set_device(Device device) { guard_.set_device(device); }
  void reset_device(Device device) { guard_.reset_device(device); }
  void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
  Device original_device() const { return guard_.original_device(); }
  Device current_device() const { return guard_.current_device(); }

 private:
  c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
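
// Usage sketch for HIPGuardMasqueradingAsCUDA. The function and tensor
// below are illustrative only (not part of this file); the point is that
// callers pass a CUDA-typed Device (or a bare index) even though the
// guard switches the HIP device under the hood:
//
//   void example_op(const at::Tensor& t) {
//     // t.device() has type CUDA on a ROCm build; the guard calls
//     // hipSetDevice for us and restores the old device on scope exit.
//     HIPGuardMasqueradingAsCUDA guard(t.device());
//     // ... launch work on t's device here ...
//   }
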
struct OptionalHIPGuardMasqueradingAsCUDA {
  explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {}
  explicit OptionalHIPGuardMasqueradingAsCUDA(optional<Device> device_opt) : guard_(device_opt) {}
  explicit OptionalHIPGuardMasqueradingAsCUDA(optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}

  OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
  OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
  OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
  OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;

  void set_device(Device device) { guard_.set_device(device); }
  void reset_device(Device device) { guard_.reset_device(device); }
  void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
  optional<Device> original_device() const { return guard_.original_device(); }
  optional<Device> current_device() const { return guard_.current_device(); }
  void reset() { guard_.reset(); }

 private:
  c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
struct HIPStreamGuardMasqueradingAsCUDA {
  explicit HIPStreamGuardMasqueradingAsCUDA() = delete;
  explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}

  HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
  HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
  HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
  HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;

  void reset_stream(Stream stream) { guard_.reset_stream(stream); }

  HIPStreamMasqueradingAsCUDA original_stream() const {
    return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream());
  }
  HIPStreamMasqueradingAsCUDA current_stream() const {
    return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream());
  }

  Device current_device() const { return guard_.current_device(); }
  Device original_device() const { return guard_.original_device(); }

 private:
  c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
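
// Usage sketch for HIPStreamGuardMasqueradingAsCUDA. The stream source
// shown here (getStreamFromPoolMasqueradingAsCUDA) is one plausible way
// to obtain a stream, not a requirement of the guard:
//
//   void example_async_work(c10::DeviceIndex idx) {
//     // Grab a pool stream for device idx and make it current for this
//     // scope; the previous stream and device are restored when the
//     // guard is destroyed.
//     auto stream = getStreamFromPoolMasqueradingAsCUDA(/*isHighPriority=*/false, idx);
//     HIPStreamGuardMasqueradingAsCUDA guard(stream.unwrap());
//     // ... enqueue kernels / copies on the current stream here ...
//   }
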
struct OptionalHIPStreamGuardMasqueradingAsCUDA {
  explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {}
  explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
  explicit OptionalHIPStreamGuardMasqueradingAsCUDA(optional<Stream> stream_opt) : guard_(stream_opt) {}

  OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
  OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
  OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
  OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;

  void reset_stream(Stream stream) { guard_.reset_stream(stream); }

  optional<HIPStreamMasqueradingAsCUDA> original_stream() const {
    auto r = guard_.original_stream();
    if (r.has_value()) {
      return make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()));
    } else {
      return nullopt;
    }
  }

  optional<HIPStreamMasqueradingAsCUDA> current_stream() const {
    auto r = guard_.current_stream();
    if (r.has_value()) {
      return make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()));
    } else {
      return nullopt;
    }
  }

  void reset() { guard_.reset(); }

 private:
  c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
struct HIPMultiStreamGuardMasqueradingAsCUDA {
  explicit HIPMultiStreamGuardMasqueradingAsCUDA(ArrayRef<HIPStreamMasqueradingAsCUDA> streams)
    : guard_(unwrapStreams(streams)) {}

  HIPMultiStreamGuardMasqueradingAsCUDA(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
  HIPMultiStreamGuardMasqueradingAsCUDA& operator=(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
  HIPMultiStreamGuardMasqueradingAsCUDA(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
  HIPMultiStreamGuardMasqueradingAsCUDA& operator=(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;

 private:
  c10::impl::InlineMultiStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;

  static std::vector<Stream> unwrapStreams(ArrayRef<HIPStreamMasqueradingAsCUDA> hipStreams) {
    std::vector<Stream> streams;
    streams.reserve(hipStreams.size());
    for (const HIPStreamMasqueradingAsCUDA& hipStream : hipStreams) {
      streams.push_back(hipStream);
    }
    return streams;
  }
};

}} // namespace c10::hip