TensorUtils.h 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. #pragma once
  2. #include <ATen/DimVector.h>
  3. #include <ATen/EmptyTensor.h>
  4. #include <ATen/Tensor.h>
  5. #include <ATen/TensorGeometry.h>
  6. #include <ATen/Utils.h>
  7. #include <utility>
  8. // These functions are NOT in Utils.h, because this file has a dep on Tensor.h
  9. #define TORCH_CHECK_TENSOR_ALL(cond, ...) \
  10. TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__);
  11. namespace at {
  12. // The following are utility functions for checking that arguments
  13. // make sense. These are particularly useful for native functions,
  14. // which do NO argument checking by default.
  15. struct TORCH_API TensorArg {
  16. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  17. const Tensor& tensor;
  18. const char* name;
  19. int pos; // 1-indexed
  20. TensorArg(const Tensor& tensor, const char* name, int pos)
  21. : tensor(tensor), name(name), pos(pos) {}
  22. // Try to mitigate any possibility of dangling reference to temporaries.
  23. // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
  24. TensorArg(Tensor&& tensor, const char* name, int pos) = delete;
  25. const Tensor* operator->() const {
  26. return &tensor;
  27. }
  28. const Tensor& operator*() const {
  29. return tensor;
  30. }
  31. };
  32. struct TORCH_API TensorGeometryArg {
  33. TensorGeometry tensor;
  34. const char* name;
  35. int pos; // 1-indexed
  36. /* implicit */ TensorGeometryArg(TensorArg arg)
  37. : tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {}
  38. TensorGeometryArg(TensorGeometry tensor, const char* name, int pos)
  39. : tensor(std::move(tensor)), name(name), pos(pos) {}
  40. const TensorGeometry* operator->() const {
  41. return &tensor;
  42. }
  43. const TensorGeometry& operator*() const {
  44. return tensor;
  45. }
  46. };
  47. // A string describing which function did checks on its input
  48. // arguments.
  49. // TODO: Consider generalizing this into a call stack.
  50. using CheckedFrom = const char*;
  51. // The undefined convention: singular operators assume their arguments
  52. // are defined, but functions which take multiple tensors will
  53. // implicitly filter out undefined tensors (to make it easier to perform
  54. // tests which should apply if the tensor is defined, and should not
  55. // otherwise.)
  56. //
  57. // NB: This means that the n-ary operators take lists of TensorArg,
  58. // not TensorGeometryArg, because the Tensor to TensorGeometry
  59. // conversion will blow up if you have undefined tensors.
  60. TORCH_API std::ostream& operator<<(
  61. std::ostream& out,
  62. const TensorGeometryArg& t);
  63. TORCH_API void checkDim(
  64. CheckedFrom c,
  65. const Tensor& tensor,
  66. const char* name,
  67. int pos, // 1-indexed
  68. int64_t dim);
  69. TORCH_API void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim);
  70. // NB: this is an inclusive-exclusive range
  71. TORCH_API void checkDimRange(
  72. CheckedFrom c,
  73. const TensorGeometryArg& t,
  74. int64_t dim_start,
  75. int64_t dim_end);
  76. TORCH_API void checkSameDim(
  77. CheckedFrom c,
  78. const TensorGeometryArg& t1,
  79. const TensorGeometryArg& t2);
  80. TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t);
  81. TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts);
  82. TORCH_API void checkSize(
  83. CheckedFrom c,
  84. const TensorGeometryArg& t,
  85. IntArrayRef sizes);
  86. TORCH_API void checkSize_symint(
  87. CheckedFrom c,
  88. const TensorGeometryArg& t,
  89. c10::SymIntArrayRef sizes);
  90. TORCH_API void checkSize(
  91. CheckedFrom c,
  92. const TensorGeometryArg& t,
  93. int64_t dim,
  94. int64_t size);
  95. TORCH_API void checkSize_symint(
  96. CheckedFrom c,
  97. const TensorGeometryArg& t,
  98. int64_t dim,
  99. const c10::SymInt& size);
  100. TORCH_API void checkNumel(
  101. CheckedFrom c,
  102. const TensorGeometryArg& t,
  103. int64_t numel);
  104. TORCH_API void checkSameNumel(
  105. CheckedFrom c,
  106. const TensorArg& t1,
  107. const TensorArg& t2);
  108. TORCH_API void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors);
  109. TORCH_API void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s);
  110. TORCH_API void checkScalarTypes(
  111. CheckedFrom c,
  112. const TensorArg& t,
  113. at::ArrayRef<ScalarType> l);
  114. TORCH_API void checkSameGPU(
  115. CheckedFrom c,
  116. const TensorArg& t1,
  117. const TensorArg& t2);
  118. TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors);
  119. TORCH_API void checkSameType(
  120. CheckedFrom c,
  121. const TensorArg& t1,
  122. const TensorArg& t2);
  123. TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors);
  124. TORCH_API void checkSameSize(
  125. CheckedFrom c,
  126. const TensorArg& t1,
  127. const TensorArg& t2);
  128. TORCH_API void checkAllSameSize(CheckedFrom c, ArrayRef<TensorArg> tensors);
  129. TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t);
  130. TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);
  131. // FixMe: does TensorArg slow things down?
  132. TORCH_API void checkBackend(
  133. CheckedFrom c,
  134. at::ArrayRef<Tensor> t,
  135. at::Backend backend);
  136. TORCH_API void checkDeviceType(
  137. CheckedFrom c,
  138. at::ArrayRef<Tensor> tensors,
  139. at::DeviceType device_type);
  140. TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout);
  141. TORCH_API void checkLayout(
  142. CheckedFrom c,
  143. at::ArrayRef<Tensor> tensors,
  144. at::Layout layout);
  145. // Methods for getting data_ptr if tensor is defined
  146. TORCH_API void* maybe_data_ptr(const Tensor& tensor);
  147. TORCH_API void* maybe_data_ptr(const TensorArg& tensor);
  148. TORCH_API void check_dim_size(
  149. const Tensor& tensor,
  150. int64_t dim,
  151. int64_t dim_size,
  152. int64_t size);
  153. namespace detail {
  154. TORCH_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
  155. TORCH_API std::optional<std::vector<int64_t>> computeStride(
  156. IntArrayRef oldshape,
  157. IntArrayRef oldstride,
  158. IntArrayRef newshape);
  159. TORCH_API std::optional<SymDimVector> computeStride(
  160. c10::SymIntArrayRef oldshape,
  161. c10::SymIntArrayRef oldstride,
  162. c10::SymIntArrayRef newshape);
  163. TORCH_API std::optional<DimVector> computeStride(
  164. IntArrayRef oldshape,
  165. IntArrayRef oldstride,
  166. const DimVector& newshape);
  167. } // namespace detail
  168. } // namespace at