ExpandUtils.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. #pragma once
  2. #ifndef AT_PER_OPERATOR_HEADERS
  3. #include <ATen/Functions.h>
  4. #else
  5. #include <ATen/ops/view.h>
  6. #include <ATen/ops/view_copy.h>
  7. #endif
  8. #include <ATen/Tensor.h>
  9. #include <ATen/core/DimVector.h>
  10. #include <c10/util/Exception.h>
  11. #include <c10/util/MaybeOwned.h>
  12. #include <c10/util/irange.h>
  13. #include <functional>
  14. #include <sstream>
  15. #include <tuple>
  16. #include <utility>
  17. namespace at {
  18. TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
  19. TORCH_API std::vector<SymInt> infer_size_symint(
  20. SymIntArrayRef a,
  21. SymIntArrayRef b);
  22. TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
  23. TORCH_API SymDimVector
  24. infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
  25. // Named type instead of a pair/tuple so that we can be sure to
  26. // construct the vectors in place and get NRVO.
  27. template <typename Container>
  28. struct InferExpandGeometryResult {
  29. Container sizes;
  30. Container strides;
  31. explicit InferExpandGeometryResult(size_t ndim)
  32. : sizes(ndim), strides(ndim) {}
  33. explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
  34. : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
  35. };
  36. TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
  37. inferExpandGeometry(
  38. IntArrayRef tensor_sizes,
  39. IntArrayRef tensor_strides,
  40. IntArrayRef sizes);
  41. TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
  42. IntArrayRef tensor_sizes,
  43. IntArrayRef tensor_strides,
  44. IntArrayRef sizes);
  45. TORCH_API std::vector<int64_t> infer_dense_strides(
  46. IntArrayRef tensor_sizes,
  47. IntArrayRef tensor_strides);
  48. // True if input shapes are expandable
  49. // NOTE: infer_size did a similar check, please keep them sync if change is
  50. // needed
  51. inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
  52. size_t ndim1 = shape1.size();
  53. size_t ndim2 = shape2.size();
  54. size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2;
  55. for (int64_t i = static_cast<int64_t>(ndim) - 1; i >= 0; --i) {
  56. if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 ||
  57. shape2[ndim2] == 1) {
  58. continue;
  59. }
  60. return false;
  61. }
  62. return true;
  63. }
  64. // avoid copy-construction of Tensor by using a reference_wrapper.
  65. inline void check_defined(
  66. std::initializer_list<std::reference_wrapper<const Tensor>> tensors,
  67. const char* api_name) {
  68. for (auto& t : tensors) {
  69. if (!t.get().defined()) {
  70. AT_ERROR(api_name, "(...) called with an undefined Tensor");
  71. }
  72. }
  73. }
  74. // NOTE [ ExpandUtils Borrowing ]
  75. //
  76. // Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because
  77. // expansion may not actually be needed, in which case we can improve
  78. // efficiency by returning
  79. // `c10::MaybeOwned<Tensor>::borrowed(to_expand)`. However, this means
  80. // that you need to be careful: the returned `c10::MaybeOwned<Tensor>`
  81. // must not outlive the original `Tensor` object that `to_expand`
  82. // referred to! The deleted rvalue reference overloads of these
  83. // functions help with this by preventing trivial use of a temporary
  84. // resulting from a function call, but it is still possible to make a
  85. // mistake.
  86. inline c10::MaybeOwned<Tensor> expand_inplace(
  87. const Tensor& tensor,
  88. const Tensor& to_expand) {
  89. if (tensor.sym_sizes().equals(to_expand.sym_sizes())) {
  90. return c10::MaybeOwned<Tensor>::borrowed(to_expand);
  91. }
  92. return c10::MaybeOwned<Tensor>::owned(
  93. to_expand.expand_symint(tensor.sym_sizes()));
  94. }
  95. inline c10::MaybeOwned<Tensor> expand_inplace(
  96. const Tensor& tensor,
  97. Tensor&& to_expand) = delete;
  98. inline c10::MaybeOwned<Tensor> expand_inplace(
  99. const Tensor& tensor,
  100. const Tensor& to_expand,
  101. const char* api_name) {
  102. check_defined({tensor, to_expand}, api_name);
  103. return expand_inplace(tensor, to_expand);
  104. }
  105. inline c10::MaybeOwned<Tensor> expand_inplace(
  106. const Tensor& tensor,
  107. Tensor&& to_expand,
  108. const char* api_name) = delete;
  109. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  110. expand_inplace(
  111. const Tensor& tensor,
  112. const Tensor& to_expand1,
  113. const Tensor& to_expand2) {
  114. if (tensor.sizes().equals(to_expand1.sizes()) &&
  115. tensor.sizes().equals((to_expand2.sizes()))) {
  116. return std::make_tuple(
  117. c10::MaybeOwned<Tensor>::borrowed(to_expand1),
  118. c10::MaybeOwned<Tensor>::borrowed(to_expand2));
  119. }
  120. return std::make_tuple(
  121. c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())),
  122. c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes())));
  123. }
  124. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  125. expand_inplace(
  126. const Tensor& tensor,
  127. Tensor&& to_expand1,
  128. const Tensor& to_expand2) = delete;
  129. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  130. expand_inplace(
  131. const Tensor& tensor,
  132. const Tensor& to_expand1,
  133. Tensor&& to_expand2) = delete;
  134. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  135. expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
  136. delete;
  137. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  138. expand_inplace(
  139. const Tensor& tensor,
  140. const Tensor& to_expand1,
  141. const Tensor& to_expand2,
  142. const char* api_name) {
  143. check_defined({tensor, to_expand1, to_expand2}, api_name);
  144. return expand_inplace(tensor, to_expand1, to_expand2);
  145. }
  146. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  147. expand_inplace(
  148. const Tensor& tensor,
  149. Tensor&& to_expand1,
  150. const Tensor& to_expand2,
  151. const char* api_name) = delete;
  152. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  153. expand_inplace(
  154. const Tensor& tensor,
  155. const Tensor& to_expand1,
  156. Tensor&& to_expand2,
  157. const char* api_name) = delete;
  158. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  159. expand_inplace(
  160. const Tensor& tensor,
  161. Tensor&& to_expand1,
  162. Tensor&& to_expand2,
  163. const char* api_name) = delete;
  164. // See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
  165. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  166. expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
  167. auto s1 = to_expand1.sym_sizes();
  168. auto s2 = to_expand2.sym_sizes();
  169. if (s1.equals(s2)) {
  170. return std::make_tuple(
  171. c10::MaybeOwned<Tensor>::borrowed(to_expand1),
  172. c10::MaybeOwned<Tensor>::borrowed(to_expand2));
  173. }
  174. auto expanded_size = infer_size_symdimvector(s1, s2);
  175. return std::make_tuple(
  176. c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)),
  177. c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size)));
  178. }
  179. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  180. expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
  181. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  182. expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
  183. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  184. expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
  185. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  186. expand_outplace(
  187. const Tensor& to_expand1,
  188. const Tensor& to_expand2,
  189. const char* api_name) {
  190. check_defined({to_expand1, to_expand2}, api_name);
  191. return expand_outplace(to_expand1, to_expand2);
  192. }
  193. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  194. expand_outplace(
  195. Tensor&& to_expand1,
  196. const Tensor& to_expand2,
  197. const char* api_name) = delete;
  198. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  199. expand_outplace(
  200. const Tensor& to_expand1,
  201. Tensor&& to_expand2,
  202. const char* api_name) = delete;
  203. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  204. expand_outplace(
  205. Tensor&& to_expand1,
  206. Tensor&& to_expand2,
  207. const char* api_name) = delete;
  208. inline std::tuple<
  209. c10::MaybeOwned<Tensor>,
  210. c10::MaybeOwned<Tensor>,
  211. c10::MaybeOwned<Tensor>>
  212. expand_outplace(
  213. const Tensor& to_expand1,
  214. const Tensor& to_expand2,
  215. const Tensor& to_expand3) {
  216. if (to_expand1.sizes().equals(to_expand2.sizes()) &&
  217. to_expand1.sizes().equals(to_expand3.sizes())) {
  218. return std::make_tuple(
  219. c10::MaybeOwned<Tensor>::borrowed(to_expand1),
  220. c10::MaybeOwned<Tensor>::borrowed(to_expand2),
  221. c10::MaybeOwned<Tensor>::borrowed(to_expand3));
  222. }
  223. auto expanded_size12 =
  224. infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
  225. auto expanded_size =
  226. infer_size_dimvector(expanded_size12, to_expand3.sizes());
  227. return std::make_tuple(
  228. c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
  229. c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)),
  230. c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size)));
  231. }
  232. inline std::tuple<
  233. c10::MaybeOwned<Tensor>,
  234. c10::MaybeOwned<Tensor>,
  235. c10::MaybeOwned<Tensor>>
  236. expand_outplace(
  237. Tensor&& to_expand1,
  238. const Tensor& to_expand2,
  239. const Tensor& to_expand3) = delete;
  240. inline std::tuple<
  241. c10::MaybeOwned<Tensor>,
  242. c10::MaybeOwned<Tensor>,
  243. c10::MaybeOwned<Tensor>>
  244. expand_outplace(
  245. const Tensor& to_expand1,
  246. Tensor&& to_expand2,
  247. const Tensor& to_expand3) = delete;
  248. inline std::tuple<
  249. c10::MaybeOwned<Tensor>,
  250. c10::MaybeOwned<Tensor>,
  251. c10::MaybeOwned<Tensor>>
  252. expand_outplace(
  253. Tensor&& to_expand1,
  254. Tensor&& to_expand2,
  255. const Tensor& to_expand3) = delete;
  256. inline std::tuple<
  257. c10::MaybeOwned<Tensor>,
  258. c10::MaybeOwned<Tensor>,
  259. c10::MaybeOwned<Tensor>>
  260. expand_outplace(
  261. const Tensor& to_expand1,
  262. const Tensor& to_expand2,
  263. Tensor&& to_expand3) = delete;
  264. inline std::tuple<
  265. c10::MaybeOwned<Tensor>,
  266. c10::MaybeOwned<Tensor>,
  267. c10::MaybeOwned<Tensor>>
  268. expand_outplace(
  269. Tensor&& to_expand1,
  270. const Tensor& to_expand2,
  271. Tensor&& to_expand3) = delete;
  272. inline std::tuple<
  273. c10::MaybeOwned<Tensor>,
  274. c10::MaybeOwned<Tensor>,
  275. c10::MaybeOwned<Tensor>>
  276. expand_outplace(
  277. const Tensor& to_expand1,
  278. Tensor&& to_expand2,
  279. Tensor&& to_expand3) = delete;
  280. inline std::tuple<
  281. c10::MaybeOwned<Tensor>,
  282. c10::MaybeOwned<Tensor>,
  283. c10::MaybeOwned<Tensor>>
  284. expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
  285. delete;
  286. inline std::tuple<
  287. c10::MaybeOwned<Tensor>,
  288. c10::MaybeOwned<Tensor>,
  289. c10::MaybeOwned<Tensor>>
  290. expand_outplace(
  291. const Tensor& to_expand1,
  292. const Tensor& to_expand2,
  293. const Tensor& to_expand3,
  294. const char* api_name) {
  295. check_defined({to_expand1, to_expand2, to_expand3}, api_name);
  296. return expand_outplace(to_expand1, to_expand2, to_expand3);
  297. }
  298. inline std::tuple<
  299. c10::MaybeOwned<Tensor>,
  300. c10::MaybeOwned<Tensor>,
  301. c10::MaybeOwned<Tensor>>
  302. expand_outplace(
  303. Tensor&& to_expand1,
  304. const Tensor& to_expand2,
  305. const Tensor& to_expand3,
  306. const char* api_name) = delete;
  307. inline std::tuple<
  308. c10::MaybeOwned<Tensor>,
  309. c10::MaybeOwned<Tensor>,
  310. c10::MaybeOwned<Tensor>>
  311. expand_outplace(
  312. const Tensor& to_expand1,
  313. Tensor&& to_expand2,
  314. const Tensor& to_expand3,
  315. const char* api_name) = delete;
  316. inline std::tuple<
  317. c10::MaybeOwned<Tensor>,
  318. c10::MaybeOwned<Tensor>,
  319. c10::MaybeOwned<Tensor>>
  320. expand_outplace(
  321. Tensor&& to_expand1,
  322. Tensor&& to_expand2,
  323. const Tensor& to_expand3,
  324. const char* api_name) = delete;
  325. inline std::tuple<
  326. c10::MaybeOwned<Tensor>,
  327. c10::MaybeOwned<Tensor>,
  328. c10::MaybeOwned<Tensor>>
  329. expand_outplace(
  330. const Tensor& to_expand1,
  331. const Tensor& to_expand2,
  332. Tensor&& to_expand3,
  333. const char* api_name) = delete;
  334. inline std::tuple<
  335. c10::MaybeOwned<Tensor>,
  336. c10::MaybeOwned<Tensor>,
  337. c10::MaybeOwned<Tensor>>
  338. expand_outplace(
  339. Tensor&& to_expand1,
  340. const Tensor& to_expand2,
  341. Tensor&& to_expand3,
  342. const char* api_name) = delete;
  343. inline std::tuple<
  344. c10::MaybeOwned<Tensor>,
  345. c10::MaybeOwned<Tensor>,
  346. c10::MaybeOwned<Tensor>>
  347. expand_outplace(
  348. const Tensor& to_expand1,
  349. Tensor&& to_expand2,
  350. Tensor&& to_expand3,
  351. const char* api_name) = delete;
  352. inline std::tuple<
  353. c10::MaybeOwned<Tensor>,
  354. c10::MaybeOwned<Tensor>,
  355. c10::MaybeOwned<Tensor>>
  356. expand_outplace(
  357. Tensor&& to_expand1,
  358. Tensor&& to_expand2,
  359. Tensor&& to_expand3,
  360. const char* api_name) = delete;
  361. inline c10::MaybeOwned<Tensor> expand_size(
  362. const Tensor& to_expand,
  363. IntArrayRef sizes) {
  364. if (to_expand.sizes().equals(sizes)) {
  365. return c10::MaybeOwned<Tensor>::borrowed(to_expand);
  366. }
  367. return c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes));
  368. }
  369. inline c10::MaybeOwned<Tensor> expand_size(
  370. Tensor&& to_expand,
  371. IntArrayRef sizes) = delete;
  372. inline c10::MaybeOwned<Tensor> expand_size(
  373. const Tensor& to_expand,
  374. IntArrayRef sizes,
  375. const char* api_name) {
  376. check_defined({to_expand}, api_name);
  377. return expand_size(to_expand, sizes);
  378. }
  379. inline c10::MaybeOwned<Tensor> expand_size(
  380. Tensor&& to_expand,
  381. IntArrayRef sizes,
  382. const char* api_name) = delete;
  383. inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
  384. // expands a list of Tensors; ignores undefined (null) tensors
  385. bool first = true;
  386. DimVector sizes;
  387. for (const auto i : c10::irange(to_expand.size())) {
  388. if (!to_expand[i].defined()) {
  389. continue;
  390. } else if (first) {
  391. sizes = to_expand[i].sizes();
  392. first = false;
  393. } else {
  394. sizes = infer_size_dimvector(sizes, to_expand[i].sizes());
  395. }
  396. }
  397. std::vector<Tensor> result(to_expand.size());
  398. for (const auto i : c10::irange(to_expand.size())) {
  399. if (!to_expand[i].defined()) {
  400. continue;
  401. } else if (to_expand[i].sizes().equals(sizes)) {
  402. result[i] = to_expand[i];
  403. } else {
  404. result[i] = to_expand[i].expand(sizes);
  405. }
  406. }
  407. return result;
  408. }
  409. template <typename T>
  410. inline Tensor _sum_to(
  411. Tensor tensor,
  412. const c10::ArrayRef<T> shape,
  413. bool always_return_non_view = false) {
  414. if (shape.size() == 0) {
  415. return tensor.sum();
  416. }
  417. auto sizes = at::symint::sizes<T>(tensor);
  418. c10::SmallVector<int64_t, 8> reduce_dims;
  419. const int64_t leading_dims = sizes.size() - shape.size();
  420. for (const auto i : c10::irange(leading_dims)) {
  421. reduce_dims.push_back(i);
  422. }
  423. for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
  424. if (shape[i - leading_dims] == 1 &&
  425. TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) {
  426. reduce_dims.push_back(i);
  427. }
  428. }
  429. if (!reduce_dims.empty()) {
  430. tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
  431. }
  432. if (always_return_non_view) {
  433. // This is only actually used by the functionalization pass.
  434. // We want to be able to guarantee that this function doesn't return a view
  435. // of the input.
  436. return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape)
  437. : tensor.clone();
  438. } else {
  439. return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor;
  440. }
  441. }
  442. inline Tensor sum_to(
  443. Tensor tensor,
  444. const c10::SymIntArrayRef shape,
  445. bool always_return_non_view = false) {
  446. return _sum_to(std::move(tensor), shape, always_return_non_view);
  447. }
  448. // Sums `tensor` repeatedly to produce a tensor of shape `shape`.
  449. // Precondition: is_expandable_to(shape, tensor.sizes()) must be true
  450. inline Tensor sum_to(
  451. Tensor tensor,
  452. const IntArrayRef shape,
  453. bool always_return_non_view = false) {
  454. return _sum_to(std::move(tensor), shape, always_return_non_view);
  455. }
  456. static inline bool is_expandable_to(
  457. SymIntArrayRef shape,
  458. c10::SymIntArrayRef desired) {
  459. size_t ndim = shape.size();
  460. size_t target_dim = desired.size();
  461. if (ndim > target_dim) {
  462. return false;
  463. }
  464. for (const auto i : c10::irange(ndim)) {
  465. const auto& size = shape[ndim - i - 1];
  466. const auto& target = desired[target_dim - i - 1];
  467. if (size != target && size != 1) {
  468. return false;
  469. }
  470. }
  471. return true;
  472. }
  473. static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
  474. auto sym_shape = c10::SymIntArrayRef(
  475. reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
  476. auto sym_desired = c10::SymIntArrayRef(
  477. reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
  478. return is_expandable_to(sym_shape, sym_desired);
  479. }
  480. } // namespace at