TensorIndexing.h

#pragma once

#include <ATen/ExpandUtils.h>
#include <ATen/ScalarOps.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/TensorBody.h>
#include <c10/core/SymInt.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/alias.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/zeros.h>
#endif

#include <ATen/core/List.h>

#include <utility>

namespace at::indexing {

constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);

enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };

constexpr c10::nullopt_t None = c10::nullopt;

struct TORCH_API EllipsisIndexType final {
  EllipsisIndexType() = default;
};
TORCH_API extern const EllipsisIndexType Ellipsis;

struct TORCH_API Slice final {
 public:
  Slice(
      std::optional<c10::SymInt> start_index = c10::nullopt,
      std::optional<c10::SymInt> stop_index = c10::nullopt,
      std::optional<c10::SymInt> step_index = c10::nullopt) {
    if (!step_index.has_value()) {
      step_ = c10::SymInt(1);
    } else {
      step_ = std::move(step_index).value();
    }

    TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");

    if (!start_index.has_value()) {
      start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
    } else {
      start_ = std::move(start_index).value();
    }

    if (!stop_index.has_value()) {
      stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
    } else {
      stop_ = std::move(stop_index).value();
    }
  }

  inline c10::SymInt start() const {
    return start_;
  }

  inline c10::SymInt stop() const {
    return stop_;
  }

  inline c10::SymInt step() const {
    return step_;
  }

 private:
  c10::SymInt start_;
  c10::SymInt stop_;
  c10::SymInt step_;
};

TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);

// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
// `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
// into its equivalent `std::vector<TensorIndex>`, so that further tensor
// indexing operations can be performed using the supplied indices.
//
// There is one-to-one correspondence between Python and C++ tensor index types:
// Python                  | C++
// ------------------------------------------------------------------
// `None`                  | `at::indexing::None`
// `Ellipsis`              | `at::indexing::Ellipsis`
// `...`                   | `"..."`
// `123`                   | `123`
// `True` / `False`        | `true` / `false`
// `:`                     | `Slice()` / `Slice(None, None)`
// `::`                    | `Slice()` / `Slice(None, None, None)`
// `1:`                    | `Slice(1, None)`
// `1::`                   | `Slice(1, None, None)`
// `:3`                    | `Slice(None, 3)`
// `:3:`                   | `Slice(None, 3, None)`
// `::2`                   | `Slice(None, None, 2)`
// `1:3`                   | `Slice(1, 3)`
// `1::2`                  | `Slice(1, None, 2)`
// `:3:2`                  | `Slice(None, 3, 2)`
// `1:3:2`                 | `Slice(1, 3, 2)`
// `torch.tensor([1, 2])`  | `torch::tensor({1, 2})`
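//
// As an illustrative sketch (an assumption, not part of this header): these
// index types are what the `Tensor::index` / `Tensor::index_put_` overloads
// taking `ArrayRef<TensorIndex>` consume, so a Python expression such as
// `y = x[1:3, None, ..., 0]` can be written in C++ roughly as:
//
//   using namespace torch::indexing;
//   torch::Tensor x = torch::rand({4, 5, 6});
//   // Python: y = x[1:3, None, ..., 0]
//   torch::Tensor y = x.index({Slice(1, 3), None, "...", 0});
//   // Python: x[1:3, None, ..., 0] = 1
//   x.index_put_({Slice(1, 3), None, "...", 0}, 1);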
struct TORCH_API TensorIndex final {
  // Case 1: `at::indexing::None`
  TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}

  // Case 2: "..." / `at::indexing::Ellipsis`
  TensorIndex(at::indexing::EllipsisIndexType)
      : type_(TensorIndexType::Ellipsis) {}
  TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
    TORCH_CHECK_VALUE(
        strcmp(str, "...") == 0,
        "Expected \"...\" to represent an ellipsis index, but got \"",
        str,
        "\"");
  }

  // Case 3: (Sym) Integer value
  TensorIndex(SymInt integer)
      : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
  TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
  TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}

  // Case 4: Boolean value
  template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
  TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}

  // Case 5: Slice represented in `at::indexing::Slice` form
  TensorIndex(Slice slice)
      : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}

  // Case 6: Tensor value
  TensorIndex(Tensor tensor)
      : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}

  inline bool is_none() const {
    return type_ == TensorIndexType::None;
  }

  inline bool is_ellipsis() const {
    return type_ == TensorIndexType::Ellipsis;
  }

  inline bool is_integer() const {
    return type_ == TensorIndexType::SymInt;
  }

  inline SymInt integer() const {
    return integer_;
  }

  inline bool is_boolean() const {
    return type_ == TensorIndexType::Boolean;
  }

  inline bool boolean() const {
    return boolean_;
  }

  inline bool is_slice() const {
    return type_ == TensorIndexType::Slice;
  }

  inline const Slice& slice() const {
    return slice_;
  }

  inline bool is_tensor() const {
    return type_ == TensorIndexType::Tensor;
  }

  inline const Tensor& tensor() const {
    return tensor_;
  }

 private:
  SymInt integer_ = 0;
  bool boolean_ = false;
  Slice slice_;
  Tensor tensor_;
  TensorIndexType type_;
};

TORCH_API std::ostream& operator<<(
    std::ostream& stream,
    const TensorIndex& tensor_index);
TORCH_API std::ostream& operator<<(
    std::ostream& stream,
    const std::vector<TensorIndex>& tensor_indices);

namespace impl {

inline Tensor applySlice(
    const Tensor& self,
    int64_t dim,
    c10::SymInt start,
    c10::SymInt stop,
    c10::SymInt step,
    bool disable_slice_optimization,
    const at::Device& self_device,
    const std::optional<SymIntArrayRef>& self_sizes) {
  // TODO: implement negative step
  TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");

  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    // Skip this optimization if we are tracing, as the trace may be polymorphic
    // over the shape of the `self` tensor, and we still want to record
    // the slice.
    SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
        ? (*self_sizes)[dim]
        : self.sym_size(dim);
    if (!disable_slice_optimization &&
        TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) &&
        TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) {
      return self;
    }
  }
  return self.slice_symint(
      dim, std::move(start), std::move(stop), std::move(step));
}

inline Tensor applySelect(
    const Tensor& self,
    int64_t dim,
    SymInt index,
    int64_t real_dim,
    const at::Device& /*self_device*/,
    const std::optional<SymIntArrayRef>& self_sizes) {
  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    auto maybe_index = index.maybe_as_int();
    if (maybe_index.has_value()) {
      TORCH_CHECK_INDEX(
          !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
          "invalid index of a 0-dim tensor. ",
          "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
    }

    auto size = (*self_sizes)[dim];
    // Note: `size >= -index` is not equivalent to `size > -1 - index` if index
    // is INT64_MIN. For std::numeric_limits<int64_t>::min(), the result of
    // unary minus is undefined by the standard but in practice is equal to
    // self. On the other hand, index wrapping is valid for all negative
    // int64_t values, as x[INT64_MIN] is the same as x[INT64_MAX].
    TORCH_CHECK_INDEX(
        size > -1 - index && size > index,
        "index ",
        index,
        " is out of bounds for dimension ",
        real_dim,
        " with size ",
        size);
  }

  // if the index is negative, do not normalize it because that would fix the
  // index on the current tensor size in the tracer. aten::select also works on
  // negative indices
  return self.select_symint(dim, std::move(index));
}

inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
  // booleans add a dimension of size 1. true indexes this dimension as if 0:,
  // false as empty.
  if (value) {
    return at::empty({1}, self.options().dtype(kLong)).fill_(0.);
  } else {
    return at::empty({0}, self.options().dtype(kLong));
  }
}

inline Tensor boolToIndexingTensorNonNativeDeviceType(
    const Tensor& self,
    bool value) {
  // booleans add a dimension of size 1. true indexes this dimension as if 0:,
  // false as empty.
  if (value) {
    return at::zeros({1}, self.options().dtype(kLong));
  } else {
    return at::empty({0}, self.options().dtype(kLong));
  }
}

inline Tensor boolToIndexingTensor(
    const Tensor& self,
    bool value,
    const at::Device& self_device) {
  if (self_device == at::kCPU || self_device == at::kCUDA) {
    return boolToIndexingTensorCPUOrCUDA(self, value);
  } else {
    return boolToIndexingTensorNonNativeDeviceType(self, value);
  }
}
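
// Illustrative example (an assumption, not part of the original header): on
// CPU/CUDA a `true` index becomes `tensor([0])` (select the whole inserted
// size-1 dimension) and `false` becomes an empty long tensor (select nothing):
//
//   auto t = at::ones({2, 2});
//   auto idx_true = boolToIndexingTensor(t, /*value=*/true, t.device());
//   // idx_true:  long tensor of shape {1} containing 0
//   auto idx_false = boolToIndexingTensor(t, /*value=*/false, t.device());
//   // idx_false: long tensor of shape {0}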

inline Tensor scalarToTensorNonNativeDeviceType(
    const Scalar& v,
    const TensorOptions& options) {
  return at::scalar_tensor(v, options);
}

inline void recordTensorIndex(
    const Tensor& tensor,
    std::vector<Tensor>& outIndices,
    int64_t* dim_ptr) {
  // TODO: check scalarType
  outIndices.resize(*dim_ptr + 1);
  outIndices[*dim_ptr] = tensor;
  (*dim_ptr)++;
}

inline c10::List<::std::optional<Tensor>> typeConvertIndices(
    const Tensor& /*self*/,
    std::vector<Tensor>&& indices) {
  c10::List<::std::optional<Tensor>> converted_inds;
  converted_inds.reserve(indices.size());
  for (auto&& i : std::move(indices)) {
    converted_inds.push_back(std::move(i));
  }
  return converted_inds;
}

// NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
// function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
// `count_specified_dimensions` is on the hot path of Python tensor multi-dim
// indexing (i.e. it's called by `applySlicing` which is called by
// `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
// than one dimension). If we were to merge the Python/C++
// `count_specified_dimensions` function, on the Python side we would have to
// construct a `std::vector` container to be consumed by the C++
// `count_specified_dimensions` function, which adds 100s of nanoseconds
// overhead and is undesirable.
inline int64_t count_specified_dimensions(
    const ArrayRef<TensorIndex>& indices) {
  // Count the number of indexed dimensions (everything but ellipsis and None)
  int64_t count = 0;
  for (auto& obj : indices) {
    if (obj.is_tensor()) {
      auto& tensor = obj.tensor();
      if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
        count += tensor.dim();
      } else {
        count++;
      }
    } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
      count++;
    }
  }
  return count;
}
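
// Illustrative example (an assumption, not part of the original header):
// integers, slices, and integer tensors each consume one dimension, byte/bool
// mask tensors consume `tensor.dim()` dimensions, and `None` / ellipsis /
// booleans consume none:
//
//   // Python: x[None, ..., 0, :, mask] with mask of shape (2, 3), dtype bool
//   std::vector<TensorIndex> idx = {
//       None, "...", 0, Slice(), at::zeros({2, 3}, at::kBool)};
//   // count_specified_dimensions(idx) == 1 + 1 + 2 == 4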

} // namespace impl

// NOTE: Many functions below are only for consumption from Python indexing
// implementation; they include:
//
// - `Tensor scalarToTensor(...)`
// - `IntArrayRef slicePrefix1sSize(...)`
// - `void copy_to(...)`
// - `Tensor handleDimInMultiDimIndexing(...)`
// - `Tensor dispatch_index(...)`
// - `Tensor dispatch_index_put_(...)`
// - `Tensor get_item(...)`
// - `void set_item(...)`
//
// The rest of the functions are in `at::indexing::impl` namespace, signifying
// that they shouldn't be used from Python indexing implementation.
inline Tensor scalarToTensor(
    const Scalar& v,
    const TensorOptions& options,
    const at::Device& self_device) {
  if (self_device == at::kCPU && !v.isSymbolic()) {
    return at::detail::scalar_tensor_static(
        v, options.dtype_opt()->toScalarType(), self_device);
  } else {
    return impl::scalarToTensorNonNativeDeviceType(v, options);
  }
}

// To match numpy semantics:
// As a special case for backwards compatibility,
// strip away unit dimensions from the left of 'src'
inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
  size_t first_non1_src = sizes.size();
  for (const auto i : c10::irange(sizes.size())) {
    // Unbacked SymInt has different behavior, but this is sound because
    // failing to slice will only ever cause an error, not divergent
    // behavior
    if (!sizes[i].has_hint() || sizes[i] != 1) {
      first_non1_src = i;
      break;
    }
  }
  return sizes.slice(first_non1_src);
}
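
// Illustrative example (an assumption, not part of the original header):
// leading size-1 dimensions are stripped, matching numpy assignment semantics:
//
//   sizes {1, 1, 3, 4} -> {3, 4}
//   sizes {2, 1, 3}    -> {2, 1, 3}   (unchanged; only *leading* 1s are removed)
//
// so a value of shape (1, 1, 3, 4) can be assigned to a (3, 4) destination.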

inline void copy_to(const Tensor& dst, const Tensor& src) {
  if (dst.sym_sizes().equals(src.sym_sizes())) {
    // A shortcut to avoid generating hard-coded constant sizes during tracing.
    // This is not a perfect solution: when src & dst have different shapes,
    // constants will still appear. Users can workaround that case by
    // dst[index..] = src.reshape(..)
    dst.copy_(src);
    return;
  } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
    dst.fill_(src);
    return;
  }
  auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
  c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
  dst.copy_(*b_src);
}

// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
inline Tensor handleDimInMultiDimIndexing(
    const Tensor& prev_dim_result,
    const Tensor& original_tensor,
    const TensorIndex& index,
    int64_t* dim_ptr,
    int64_t* specified_dims_ptr,
    int64_t real_dim,
    std::vector<Tensor>& outIndices,
    bool disable_slice_optimization,
    const at::Device& original_tensor_device,
    const std::optional<SymIntArrayRef>& prev_dim_result_sizes) {
  if (index.is_integer()) {
    return impl::applySelect(
        prev_dim_result,
        *dim_ptr,
        index.integer(),
        real_dim,
        original_tensor_device,
        prev_dim_result_sizes);
  } else if (index.is_slice()) {
    Tensor result = impl::applySlice(
        prev_dim_result,
        *dim_ptr,
        index.slice().start(),
        index.slice().stop(),
        index.slice().step(),
        /*disable_slice_optimization=*/disable_slice_optimization,
        original_tensor_device,
        prev_dim_result_sizes);
    (*dim_ptr)++;
    return result;
  } else if (index.is_ellipsis()) {
    (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
    return prev_dim_result;
  } else if (index.is_none()) {
    Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
    (*dim_ptr)++;
    return result;
  } else if (index.is_boolean()) {
    Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
    impl::recordTensorIndex(
        impl::boolToIndexingTensor(
            result, index.boolean(), original_tensor_device),
        outIndices,
        dim_ptr);
    return result;
  } else if (index.is_tensor()) {
    Tensor result = prev_dim_result;
    const Tensor& tensor = index.tensor();
    auto scalar_type = tensor.scalar_type();
    if (tensor.dim() == 0 &&
        at::isIntegralType(scalar_type, /*includeBool=*/true)) {
      if (scalar_type != at::kByte && scalar_type != at::kBool) {
        result = impl::applySelect(
            result,
            *dim_ptr,
            tensor.item<int64_t>(),
            real_dim,
            original_tensor_device,
            prev_dim_result_sizes);
      } else {
        result = result.unsqueeze(*dim_ptr);
        if (scalar_type == at::kBool) {
          impl::recordTensorIndex(
              impl::boolToIndexingTensor(
                  result, tensor.item<bool>() != 0, original_tensor_device),
              outIndices,
              dim_ptr);
        } else {
          impl::recordTensorIndex(
              impl::boolToIndexingTensor(
                  result, tensor.item<uint8_t>() != 0, original_tensor_device),
              outIndices,
              dim_ptr);
        }
      }
    } else {
      impl::recordTensorIndex(tensor, outIndices, dim_ptr);
    }
    return result;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
  }
}
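
// Illustrative walk-through (an assumption, not part of the original header)
// of how `*dim_ptr` advances for a tensor `x` of shape (2, 3, 4) indexed with
// `{None, 0, "...", Slice()}` (so `specified_dims` == 2):
//
//   None    -> unsqueeze at dim 0;             dim_ptr: 0 -> 1
//   0       -> select at dim 1 (dim removed);  dim_ptr stays 1
//   "..."   -> skips 3 - 2 = 1 unindexed dim;  dim_ptr: 1 -> 2
//   Slice() -> slice at dim 2;                 dim_ptr: 2 -> 3
//
// yielding a result of shape (1, 3, 4).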

namespace impl {

// This mirrors `applySlicing` in
// torch/csrc/autograd/python_variable_indexing.cpp
inline Tensor applySlicing(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    std::vector<Tensor>& outIndices,
    bool disable_slice_optimization,
    const at::Device& self_device,
    const std::optional<SymIntArrayRef>& self_sizes) {
  int64_t dim = 0;
  int64_t specified_dims = impl::count_specified_dimensions(indices);

  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    TORCH_CHECK_INDEX(
        specified_dims <= (int64_t)self_sizes->size(),
        "too many indices for tensor of dimension ",
        (int)self_sizes->size());
  }

  Tensor result = self;
  for (const auto i : c10::irange(indices.size())) {
    auto& obj = indices[i];
    // See NOTE [nested tensor size for indexing]
    std::optional<SymIntArrayRef> result_sizes = result.is_nested()
        ? std::optional<SymIntArrayRef>(c10::nullopt)
        : std::optional<SymIntArrayRef>(result.sym_sizes());
    result = handleDimInMultiDimIndexing(
        /*prev_dim_result=*/result,
        /*original_tensor=*/self,
        /*index=*/obj,
        /*dim_ptr=*/&dim,
        /*specified_dims_ptr=*/&specified_dims,
        /*real_dim=*/static_cast<int64_t>(i),
        /*outIndices=*/outIndices,
        /*disable_slice_optimization=*/disable_slice_optimization,
        /*original_tensor_device=*/self_device,
        /*prev_dim_result_sizes=*/result_sizes);
  }
  return result;
}
} // namespace impl

inline Tensor dispatch_index(
    const Tensor& self,
    std::vector<Tensor>&& indices) {
  return self.index(impl::typeConvertIndices(self, std::move(indices)));
}

inline Tensor dispatch_index_put_(
    Tensor& self,
    std::vector<Tensor>&& indices,
    const Tensor& value) {
  return self.index_put_(
      impl::typeConvertIndices(self, std::move(indices)), value);
}

// NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
// functions from Python ]
//
// Question: When should we set `disable_slice_optimization` to `true` when
// calling C++ tensor indexing functions from Python indexing code?
//
// Answer: What "slice optimization" means: when we have a slicing expression
// like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
// would skip dispatching the actual slice call as an optimization. However,
// here are the cases where we DON'T want this optimization:
//
// 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
//    Reason: we always return a shallow copy for expressions such as
//    `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:, :]`,
//    we return an alias of `tensor` by doing the following:
//    ```
//    Tensor sliced = impl::applySlicing(
//        self, indices, tensorIndices, disable_slice_optimization,
//        self_device, self_sizes);
//    if (tensorIndices.empty()) {
//      if (sliced.is_same(self)) {
//        // ensure we return a shallow copy for things like x[...]
//        sliced = at::alias(sliced);
//      }
//      return sliced;
//    }
//    ```)
// 2. When we are doing JIT tracing.
//    Reason: JIT tracing needs the `self.slice(...)` call to properly trace
//    the slice operation.

// This mirrors `THPVariable_getitem` in
// torch/csrc/autograd/python_variable_indexing.cpp. See NOTE [ Setting
// `disable_slice_optimization` when calling C++ tensor indexing functions from
// Python ]
inline Tensor get_item(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    bool disable_slice_optimization = false) {
  at::Device self_device = self.device();

  // NOTE [nested tensor size for indexing]
  // nested tensor does not have a size (yet), so for now we represent its size
  // as null; this may need to be changed after we reach a better solution for
  // nested tensor size
  std::optional<SymIntArrayRef> self_sizes = self.is_nested()
      ? std::optional<SymIntArrayRef>(c10::nullopt)
      : std::optional<SymIntArrayRef>(self.sym_sizes());

  // handle simple types: integers, slices, none, ellipsis, bool
  if (indices.size() == 1) {
    const TensorIndex& index = indices[0];
    if (index.is_integer()) {
      return impl::applySelect(
          self, 0, index.integer(), 0, self_device, self_sizes);
    } else if (index.is_slice()) {
      return impl::applySlice(
          self,
          0,
          index.slice().start(),
          index.slice().stop(),
          index.slice().step(),
          /*disable_slice_optimization=*/true,
          self_device,
          self_sizes);
    } else if (index.is_none()) {
      return self.unsqueeze(0);
    } else if (index.is_ellipsis()) {
      return at::alias(self);
    } else if (index.is_boolean()) {
      Tensor result = self.unsqueeze(0);
      return dispatch_index(
          result,
          std::vector<Tensor>{impl::boolToIndexingTensor(
              result, index.boolean(), self_device)});
    }
  }

  std::vector<Tensor> tensorIndices;
  Tensor sliced = impl::applySlicing(
      self,
      indices,
      tensorIndices,
      disable_slice_optimization,
      self_device,
      self_sizes);
  if (tensorIndices.empty()) {
    if (sliced.is_same(self)) {
      // ensure we return a shallow copy for things like x[...]
      sliced = at::alias(sliced);
    }
    return sliced;
  }

  // indexing by tensors ("advanced" indexing)
  return dispatch_index(sliced, std::move(tensorIndices));
}
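
// Illustrative sketch (an assumption, not part of the original header):
// `get_item` is the C++ counterpart of Python `__getitem__`, so
// `y = x[0, :, None]` corresponds roughly to
//
//   Tensor y = at::indexing::get_item(x, {0, Slice(), None});
//
// Typical user code reaches this through `x.index({0, Slice(), None})` rather
// than by calling `get_item` directly.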

// This mirrors `THPVariable_setitem` in
// torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is
// a Tensor" case. See NOTE [ Setting `disable_slice_optimization` when calling
// C++ tensor indexing functions from Python ]
inline void set_item(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    const Tensor& value,
    bool disable_slice_optimization = false) {
  at::Device self_device = self.device();
  SymIntArrayRef self_sizes = self.sym_sizes();

  // handle simple types: integers, slices, ellipsis, bool
  if (indices.size() == 1) {
    const TensorIndex& index = indices[0];
    if (index.is_boolean() && !index.boolean()) {
      // do nothing for false (technically we should check the size, but we
      // don't have real 0-sized shapes)
      return;
    } else if (index.is_ellipsis()) {
      copy_to(self, value);
      return;
    } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
      copy_to(self.unsqueeze(0), value);
      return;
    } else if (index.is_integer()) {
      copy_to(
          impl::applySelect(
              self, 0, index.integer(), 0, self_device, self_sizes),
          value);
      return;
    } else if (index.is_slice()) {
      copy_to(
          impl::applySlice(
              self,
              0,
              index.slice().start(),
              index.slice().stop(),
              index.slice().step(),
              /*disable_slice_optimization=*/disable_slice_optimization,
              self_device,
              self_sizes),
          value);
      return;
    }
  }

  std::vector<Tensor> tensorIndices;
  Tensor sliced = impl::applySlicing(
      self,
      indices,
      tensorIndices,
      disable_slice_optimization,
      self_device,
      self_sizes);
  if (tensorIndices.empty()) {
    copy_to(sliced, value);
    return;
  }

  SymIntArrayRef valueSizes = value.sym_sizes();
  SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
  Tensor valuesSliced;
  if (!valueSizes.equals(slicedValueSizes)) {
    valuesSliced = value.view_symint(slicedValueSizes);
  } else {
    valuesSliced = value;
  }
  dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
  return;
}
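
// Illustrative sketch (an assumption, not part of the original header):
// `set_item` is the C++ counterpart of Python `__setitem__` for tensor values,
// so `x[1:3, 0] = v` corresponds roughly to
//
//   at::indexing::set_item(x, {Slice(1, 3), 0}, v);
//
// Typical user code reaches this through `x.index_put_({Slice(1, 3), 0}, v)`.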

} // namespace at::indexing