HIPStreamMasqueradingAsCUDA.h 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. #pragma once
  2. #include <c10/hip/HIPStream.h>
  3. // Use of c10::hip namespace here makes hipification easier, because
  4. // I don't have to also fix namespaces. Sorry!
  5. namespace c10 { namespace hip {
  6. // See Note [Masquerading as CUDA] for motivation
  7. class HIPStreamMasqueradingAsCUDA {
  8. public:
  9. enum Unchecked { UNCHECKED };
  10. explicit HIPStreamMasqueradingAsCUDA(Stream stream)
  11. : HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) {
  12. // We did the coercion unchecked; check that it was right.
  13. TORCH_CHECK(stream.device().is_cuda() /* !!! */);
  14. }
  15. explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream)
  16. // Unsafely coerce the "CUDA" stream into a HIP stream
  17. : stream_(
  18. HIPStream(
  19. Stream(
  20. Stream::UNSAFE,
  21. Device(c10::DeviceType::HIP, stream.device_index()),
  22. stream.id())
  23. )
  24. ) {}
  25. // New constructor, just for this. Does NOT coerce.
  26. explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {}
  27. bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
  28. return stream_ == other.stream_;
  29. }
  30. bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
  31. return stream_ != other.stream_;
  32. }
  33. operator hipStream_t() const { return stream_.stream(); }
  34. operator Stream() const {
  35. // Unsafely coerce HIP stream into a "CUDA" stream
  36. return Stream(Stream::UNSAFE, device(), id());
  37. }
  38. DeviceIndex device_index() const { return stream_.device_index(); }
  39. // Unsafely coerce HIP device into CUDA device
  40. c10::DeviceType device_type() const { return c10::DeviceType::CUDA; }
  41. Device device() const {
  42. // Unsafely coerce HIP device into CUDA device
  43. return Device(c10::DeviceType::CUDA, stream_.device_index());
  44. }
  45. StreamId id() const { return stream_.id(); }
  46. bool query() const { return stream_.query(); }
  47. void synchronize() const { stream_.synchronize(); }
  48. int priority() const { return stream_.priority(); }
  49. hipStream_t stream() const { return stream_.stream(); }
  50. Stream unwrap() const {
  51. // Unsafely coerce HIP stream into "CUDA" stream
  52. return Stream(Stream::UNSAFE, device(), id());
  53. }
  54. c10::StreamData3 pack3() const noexcept {
  55. // Unsafely coerce HIP stream into "CUDA" stream before packing
  56. return unwrap().pack3();
  57. }
  58. static HIPStreamMasqueradingAsCUDA unpack3(StreamId stream_id,
  59. DeviceIndex device_index,
  60. c10::DeviceType device_type) {
  61. // NB: constructor manages CUDA->HIP translation for us
  62. return HIPStreamMasqueradingAsCUDA(Stream::unpack3(
  63. stream_id, device_index, device_type));
  64. }
  65. static std::tuple<int, int> priority_range() { return HIPStream::priority_range(); }
  66. // New method, gets the underlying HIPStream
  67. HIPStream hip_stream() const { return stream_; }
  68. private:
  69. HIPStream stream_;
  70. };
  71. HIPStreamMasqueradingAsCUDA
  72. inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, DeviceIndex device = -1) {
  73. return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device));
  74. }
  75. HIPStreamMasqueradingAsCUDA
  76. inline getStreamFromPoolMasqueradingAsCUDA(const int priority, DeviceIndex device = -1) {
  77. return HIPStreamMasqueradingAsCUDA(getStreamFromPool(priority, device));
  78. }
  79. HIPStreamMasqueradingAsCUDA
  80. inline getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) {
  81. return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device));
  82. }
  83. inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
  84. return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index));
  85. }
  86. inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
  87. return HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index));
  88. }
  89. inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) {
  90. setCurrentHIPStream(stream.hip_stream());
  91. }
  92. inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) {
  93. stream << s.hip_stream() << " (masquerading as CUDA)";
  94. return stream;
  95. }
  96. }} // namespace c10::hip
  97. namespace std {
  98. template <>
  99. struct hash<c10::hip::HIPStreamMasqueradingAsCUDA> {
  100. size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept {
  101. return std::hash<c10::Stream>{}(s.unwrap());
  102. }
  103. };
  104. } // namespace std