OpaqueTensorImpl.h 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. #pragma once
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

#include <utility>
  6. namespace at {
  7. // An "Opaque" TensorImpl -- there are no strides and (for now)
  8. // even data() is not supported (thus no pointer arithmetic).
  9. // NOTE: We could allow data() in the future, but would have to ensure pointer
  10. // arithmetic code is properly guarded.
  11. //
  12. // NOTE: This does not support resize_ (and other metadata-changing ops) because
  13. // of `shallow_copy_and_detach`. We would need to define an interface to
  14. // "shallow copy" in order to add support.
  15. template <typename OpaqueHandle>
  16. struct TORCH_API OpaqueTensorImpl : public TensorImpl {
  17. // public constructor for now...
  18. OpaqueTensorImpl(
  19. at::DispatchKeySet key_set,
  20. const caffe2::TypeMeta data_type,
  21. c10::Device device,
  22. OpaqueHandle opaque_handle,
  23. c10::IntArrayRef sizes,
  24. bool is_non_overlapping_and_dense = true)
  25. : TensorImpl(key_set, data_type, device),
  26. opaque_handle_(std::move(opaque_handle)) {
  27. set_storage_access_should_throw();
  28. set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
  29. sizes_and_strides_.set_sizes(sizes);
  30. refresh_numel();
  31. // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
  32. is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
  33. }
  34. // Destructor doesn't call release_resources because it's
  35. // unnecessary; don't forget to change that if needed!
  36. void release_resources() override {
  37. TensorImpl::release_resources();
  38. opaque_handle_ = {};
  39. }
  40. void set_size(int64_t dim, int64_t new_size) override {
  41. AT_ERROR("opaque tensors do not have set_size");
  42. }
  43. void set_stride(int64_t dim, int64_t new_stride) override {
  44. AT_ERROR("opaque tensors do not have set_stride");
  45. }
  46. void set_storage_offset(int64_t storage_offset) override {
  47. AT_ERROR("opaque tensors do not have set_storage_offset");
  48. }
  49. #ifdef DEBUG
  50. bool has_storage() const override {
  51. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
  52. !storage_, "OpaqueTensorImpl assumes that storage_ is never set");
  53. return false;
  54. }
  55. #endif
  56. /**
  57. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  58. *
  59. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  60. * see NOTE [ TensorImpl Shallow-Copying ].
  61. */
  62. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  63. const c10::VariableVersion& version_counter,
  64. bool allow_tensor_metadata_change) const override {
  65. auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
  66. key_set(),
  67. dtype(),
  68. device(),
  69. opaque_handle_,
  70. sizes_and_strides_.sizes_arrayref());
  71. copy_tensor_metadata(
  72. /*src_opaque_impl=*/this,
  73. /*dest_opaque_impl=*/impl.get(),
  74. /*version_counter=*/version_counter,
  75. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  76. impl->refresh_numel();
  77. return impl;
  78. }
  79. /**
  80. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  81. *
  82. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  83. * see NOTE [ TensorImpl Shallow-Copying ].
  84. */
  85. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  86. c10::VariableVersion&& version_counter,
  87. bool allow_tensor_metadata_change) const override {
  88. auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
  89. key_set(),
  90. dtype(),
  91. device(),
  92. opaque_handle_,
  93. sizes_and_strides_.sizes_arrayref());
  94. copy_tensor_metadata(
  95. /*src_opaque_impl=*/this,
  96. /*dest_opaque_impl=*/impl.get(),
  97. /*version_counter=*/std::move(version_counter),
  98. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  99. impl->refresh_numel();
  100. return impl;
  101. }
  102. /**
  103. * Shallow-copies data from another TensorImpl into this TensorImpl.
  104. *
  105. * For why this function doesn't check this TensorImpl's
  106. * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
  107. */
  108. void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
  109. AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
  110. auto opaque_impl =
  111. static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
  112. copy_tensor_metadata(
  113. /*src_impl=*/opaque_impl,
  114. /*dest_impl=*/this,
  115. /*version_counter=*/version_counter(),
  116. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  117. refresh_numel();
  118. }
  119. const OpaqueHandle& opaque_handle() const {
  120. return opaque_handle_;
  121. }
  122. OpaqueHandle& unsafe_opaque_handle() {
  123. return opaque_handle_;
  124. }
  125. protected:
  126. /**
  127. * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
  128. * storage_offset) from one TensorImpl to another TensorImpl.
  129. *
  130. * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
  131. * [ TensorImpl Shallow-Copying ].
  132. */
  133. static void copy_tensor_metadata(
  134. const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
  135. OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
  136. const c10::VariableVersion& version_counter,
  137. bool allow_tensor_metadata_change) {
  138. TensorImpl::copy_tensor_metadata(
  139. src_opaque_impl,
  140. dest_opaque_impl,
  141. version_counter,
  142. allow_tensor_metadata_change);
  143. // OpaqueTensorImpl-specific fields.
  144. dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
  145. }
  146. static void copy_tensor_metadata(
  147. const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
  148. OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
  149. c10::VariableVersion&& version_counter,
  150. bool allow_tensor_metadata_change) {
  151. TensorImpl::copy_tensor_metadata(
  152. src_opaque_impl,
  153. dest_opaque_impl,
  154. std::move(version_counter),
  155. allow_tensor_metadata_change);
  156. // OpaqueTensorImpl-specific fields.
  157. dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
  158. }
  159. private:
  160. const char* tensorimpl_type_name() const override {
  161. return "OpaqueTensorImpl";
  162. }
  163. OpaqueHandle opaque_handle_;
  164. };
  165. } // namespace at