MaybeOwned.h 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. #pragma once
  2. #include <c10/macros/Macros.h>
  3. #include <c10/util/Exception.h>
  4. #include <memory>
  5. #include <type_traits>
  6. #include <utility>
  7. namespace c10 {
  8. /// MaybeOwnedTraits<T> describes how to borrow from T. Here is how we
  9. /// can implement borrowing from an arbitrary type T using a raw
  10. /// pointer to const:
  11. template <typename T>
  12. struct MaybeOwnedTraitsGenericImpl {
  13. using owned_type = T;
  14. using borrow_type = const T*;
  15. static borrow_type createBorrow(const owned_type& from) {
  16. return &from;
  17. }
  18. static void assignBorrow(borrow_type& lhs, borrow_type rhs) {
  19. lhs = rhs;
  20. }
  21. static void destroyBorrow(borrow_type& /*toDestroy*/) {}
  22. static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
  23. return *borrow;
  24. }
  25. static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
  26. return borrow;
  27. }
  28. static bool debugBorrowIsValid(const borrow_type& borrow) {
  29. return borrow != nullptr;
  30. }
  31. };
  32. /// It is possible to eliminate the extra layer of indirection for
  33. /// borrows for some types that we control. For examples, see
  34. /// intrusive_ptr.h and TensorBody.h.
  35. template <typename T>
  36. struct MaybeOwnedTraits;
  37. // Explicitly enable MaybeOwned<shared_ptr<T>>, rather than allowing
  38. // MaybeOwned to be used for any type right away.
  39. template <typename T>
  40. struct MaybeOwnedTraits<std::shared_ptr<T>>
  41. : public MaybeOwnedTraitsGenericImpl<std::shared_ptr<T>> {};
  42. /// A smart pointer around either a borrowed or owned T. When
  43. /// constructed with borrowed(), the caller MUST ensure that the
  44. /// borrowed-from argument outlives this MaybeOwned<T>. Compare to
  45. /// Rust's std::borrow::Cow
  46. /// (https://doc.rust-lang.org/std/borrow/enum.Cow.html), but note
  47. /// that it is probably not suitable for general use because C++ has
  48. /// no borrow checking. Included here to support
  49. /// Tensor::expect_contiguous.
  50. template <typename T>
  51. class MaybeOwned final {
  52. using borrow_type = typename MaybeOwnedTraits<T>::borrow_type;
  53. using owned_type = typename MaybeOwnedTraits<T>::owned_type;
  54. bool isBorrowed_;
  55. union {
  56. borrow_type borrow_;
  57. owned_type own_;
  58. };
  59. /// Don't use this; use borrowed() instead.
  60. explicit MaybeOwned(const owned_type& t)
  61. : isBorrowed_(true), borrow_(MaybeOwnedTraits<T>::createBorrow(t)) {}
  62. /// Don't use this; use owned() instead.
  63. explicit MaybeOwned(T&& t) noexcept(std::is_nothrow_move_constructible_v<T>)
  64. : isBorrowed_(false), own_(std::move(t)) {}
  65. /// Don't use this; use owned() instead.
  66. template <class... Args>
  67. explicit MaybeOwned(std::in_place_t, Args&&... args)
  68. : isBorrowed_(false), own_(std::forward<Args>(args)...) {}
  69. public:
  70. explicit MaybeOwned() : isBorrowed_(true), borrow_() {}
  71. // Copying a borrow yields another borrow of the original, as with a
  72. // T*. Copying an owned T yields another owned T for safety: no
  73. // chains of borrowing by default! (Note you could get that behavior
  74. // with MaybeOwned<T>::borrowed(*rhs) if you wanted it.)
  75. MaybeOwned(const MaybeOwned& rhs) : isBorrowed_(rhs.isBorrowed_) {
  76. if (C10_LIKELY(rhs.isBorrowed_)) {
  77. MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
  78. } else {
  79. new (&own_) T(rhs.own_);
  80. }
  81. }
  82. MaybeOwned& operator=(const MaybeOwned& rhs) {
  83. if (this == &rhs) {
  84. return *this;
  85. }
  86. if (C10_UNLIKELY(!isBorrowed_)) {
  87. if (rhs.isBorrowed_) {
  88. own_.~T();
  89. MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
  90. isBorrowed_ = true;
  91. } else {
  92. own_ = rhs.own_;
  93. }
  94. } else {
  95. if (C10_LIKELY(rhs.isBorrowed_)) {
  96. MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
  97. } else {
  98. MaybeOwnedTraits<T>::destroyBorrow(borrow_);
  99. new (&own_) T(rhs.own_);
  100. isBorrowed_ = false;
  101. }
  102. }
  103. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isBorrowed_ == rhs.isBorrowed_);
  104. return *this;
  105. }
  106. MaybeOwned(MaybeOwned&& rhs) noexcept(
  107. // NOLINTNEXTLINE(*-noexcept-move-*)
  108. std::is_nothrow_move_constructible_v<T> &&
  109. std::is_nothrow_move_assignable_v<borrow_type>)
  110. : isBorrowed_(rhs.isBorrowed_) {
  111. if (C10_LIKELY(rhs.isBorrowed_)) {
  112. MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
  113. } else {
  114. new (&own_) T(std::move(rhs.own_));
  115. }
  116. }
  117. MaybeOwned& operator=(MaybeOwned&& rhs) noexcept(
  118. std::is_nothrow_move_assignable_v<T> &&
  119. std::is_nothrow_move_assignable_v<borrow_type> &&
  120. std::is_nothrow_move_constructible_v<T> &&
  121. // NOLINTNEXTLINE(*-noexcept-move-*)
  122. std::is_nothrow_destructible_v<T> &&
  123. std::is_nothrow_destructible_v<borrow_type>) {
  124. if (this == &rhs) {
  125. return *this;
  126. }
  127. if (C10_UNLIKELY(!isBorrowed_)) {
  128. if (rhs.isBorrowed_) {
  129. own_.~T();
  130. MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
  131. isBorrowed_ = true;
  132. } else {
  133. own_ = std::move(rhs.own_);
  134. }
  135. } else {
  136. if (C10_LIKELY(rhs.isBorrowed_)) {
  137. MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
  138. } else {
  139. MaybeOwnedTraits<T>::destroyBorrow(borrow_);
  140. new (&own_) T(std::move(rhs.own_));
  141. isBorrowed_ = false;
  142. }
  143. }
  144. return *this;
  145. }
  146. static MaybeOwned borrowed(const T& t) {
  147. return MaybeOwned(t);
  148. }
  149. static MaybeOwned owned(T&& t) noexcept(
  150. std::is_nothrow_move_constructible_v<T>) {
  151. return MaybeOwned(std::move(t));
  152. }
  153. template <class... Args>
  154. static MaybeOwned owned(std::in_place_t, Args&&... args) {
  155. return MaybeOwned(std::in_place, std::forward<Args>(args)...);
  156. }
  157. ~MaybeOwned() noexcept(
  158. // NOLINTNEXTLINE(*-noexcept-destructor)
  159. std::is_nothrow_destructible_v<T> &&
  160. std::is_nothrow_destructible_v<borrow_type>) {
  161. if (C10_UNLIKELY(!isBorrowed_)) {
  162. own_.~T();
  163. } else {
  164. MaybeOwnedTraits<T>::destroyBorrow(borrow_);
  165. }
  166. }
  167. // This is an implementation detail! You should know what you're doing
  168. // if you are testing this. If you just want to guarantee ownership move
  169. // this into a T
  170. bool unsafeIsBorrowed() const {
  171. return isBorrowed_;
  172. }
  173. const T& operator*() const& {
  174. if (isBorrowed_) {
  175. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
  176. MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
  177. }
  178. return C10_LIKELY(isBorrowed_)
  179. ? MaybeOwnedTraits<T>::referenceFromBorrow(borrow_)
  180. : own_;
  181. }
  182. const T* operator->() const {
  183. if (isBorrowed_) {
  184. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
  185. MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
  186. }
  187. return C10_LIKELY(isBorrowed_)
  188. ? MaybeOwnedTraits<T>::pointerFromBorrow(borrow_)
  189. : &own_;
  190. }
  191. // If borrowed, copy the underlying T. If owned, move from
  192. // it. borrowed/owned state remains the same, and either we
  193. // reference the same borrow as before or we are an owned moved-from
  194. // T.
  195. T operator*() && {
  196. if (isBorrowed_) {
  197. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
  198. MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
  199. return MaybeOwnedTraits<T>::referenceFromBorrow(borrow_);
  200. } else {
  201. return std::move(own_);
  202. }
  203. }
  204. };
  205. } // namespace c10