TensorIterator.h
#pragma once

#include <ATen/TensorMeta.h>
#include <ATen/core/Dimname.h>
#include <ATen/core/Range.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/DynamicCast.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/SmallVector.h>
#include <c10/util/TypeCast.h>
#include <c10/util/irange.h>

#include <array>
#include <bitset>

namespace at {
class Tensor;
class OptionalTensorRef;
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
} // namespace at

// TensorIterator is a helper class for element-wise operations, such as
// arithmetic, comparisons, and trigonometric functions. It handles
// broadcasting and type conversions of operands.
//
// This is inspired by NumPy's Array Iterator API (NpyIter).
//
// The files Loops.h and Loops.cuh provide functions to build kernels that
// use TensorIterator.
//
// Example:
//
//   auto iter = TensorIteratorConfig()
//     .add_output(output)
//     .add_input(input)
//     .build()
//
// [MyKernel.cpp / MyKernel.cu]
//   cpu_kernel(iter, [](float a, float b) {
//     return a + b;
//   });
//
//   gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float {
//     return a + b;
//   });
//
// Note [Order of Construction]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// When setting up the tensor iterator configuration, the output Tensors
// have to be added first via
// TensorIteratorConfig::add_owned_output(at::Tensor). After adding all outputs,
// the inputs can be added via
// TensorIteratorConfig::add_owned_input(at::Tensor).
// Adding another output after inputs have been added will raise an exception.
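//
// For example (a minimal sketch; `out`, `self`, and `other` are illustrative
// names for tensors defined elsewhere), outputs go in before inputs:
//
//   auto iter = TensorIteratorConfig()
//     .add_owned_output(out)      // outputs first
//     .add_owned_input(self)      // then inputs
//     .add_owned_input(other)
//     .build();
//
// Reversing the order, e.g. calling add_owned_output() after an
// add_owned_input(), is rejected when the config is assembled.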
//
// Note [Common Dtype Computation]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Some operations have a natural notion of a "common dtype" or
// "computation dtype" where all inputs are cast to one dtype, the
// operation is performed, and then the results are cast to all outputs.
//
// TensorIterator infers a common dtype if all inputs have the same dtype,
// and it computes one using type promotion rules on its inputs if
// promote_inputs_to_common_dtype_ is true. Attempting to query
// a common dtype otherwise will throw an exception.
//
// Note that the outputs are not considered when computing a common dtype.
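//
// For example (a hedged sketch; `out`, `a`, and `b` are illustrative tensors,
// say of dtypes kInt and kFloat), enabling input promotion lets the iterator
// compute a common dtype via the usual type promotion rules:
//
//   auto iter = TensorIteratorConfig()
//     .add_owned_output(out)
//     .add_owned_input(a)     // e.g. kInt
//     .add_owned_input(b)     // e.g. kFloat
//     .promote_inputs_to_common_dtype(true)
//     .build();
//   ScalarType dt = iter.common_dtype();  // kFloat under standard promotion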

namespace at {
namespace internal {
// This parameter is heuristically chosen to determine the minimum amount of
// work that warrants parallelism. For example, when summing an array, it is
// deemed inefficient to parallelise over arrays shorter than 32768. Further,
// no parallel algorithm (such as parallel_reduce) should split work into
// smaller than GRAIN_SIZE chunks.
constexpr int64_t GRAIN_SIZE = 32768;
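
// As a rough illustration (assuming ATen/Parallel.h is available; `data`,
// `n`, and `process` are hypothetical), GRAIN_SIZE is the kind of value
// passed as the grain size of a parallel loop, so that no chunk of work is
// smaller than this threshold:
//
//   at::parallel_for(0, n, GRAIN_SIZE, [&](int64_t begin, int64_t end) {
//     for (int64_t i = begin; i < end; ++i) {
//       process(data[i]);  // hypothetical per-element work
//     }
//   });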

// Storage for a non-owning Tensor, without needing to include Tensor.h
class TORCH_API OpaqueOptionalTensorRef {
  alignas(alignof(TensorBase)) std::array<char, sizeof(TensorBase)> data_{};

 public:
  OpaqueOptionalTensorRef();
  OpaqueOptionalTensorRef(const OpaqueOptionalTensorRef&) = default;
  OpaqueOptionalTensorRef& operator=(const OpaqueOptionalTensorRef&) = default;
  OpaqueOptionalTensorRef(OpaqueOptionalTensorRef&&) noexcept = default;
  OpaqueOptionalTensorRef& operator=(OpaqueOptionalTensorRef&&) noexcept =
      default;
  ~OpaqueOptionalTensorRef();

  OptionalTensorRef* get() {
    return reinterpret_cast<OptionalTensorRef*>(data_.data());
  }
  const OptionalTensorRef* get() const {
    return reinterpret_cast<const OptionalTensorRef*>(data_.data());
  }

  OptionalTensorRef& operator*() {
    return *get();
  }
  const OptionalTensorRef& operator*() const {
    return *get();
  }
  OptionalTensorRef* operator->() {
    return get();
  }
  const OptionalTensorRef* operator->() const {
    return get();
  }

  const Tensor& getTensor() const;
};
} // namespace internal

struct TORCH_API OperandInfo {
  using StrideVector = SmallVector<int64_t, 6>;
  OperandInfo() = default;
  C10_ALWAYS_INLINE explicit OperandInfo(c10::MaybeOwned<TensorBase>&& t) {
    if (t->defined()) {
      device = t->device();
      target_dtype = t->scalar_type();
      current_dtype = target_dtype;
    }
    tensor(std::move(t));
    validate();
  }

  C10_ALWAYS_INLINE OperandInfo(const OperandInfo&) = default;
  C10_ALWAYS_INLINE OperandInfo& operator=(const OperandInfo&) = default;
  C10_ALWAYS_INLINE OperandInfo(OperandInfo&&) noexcept = default;
  C10_ALWAYS_INLINE OperandInfo& operator=(OperandInfo&&) noexcept = default;
  C10_ALWAYS_INLINE ~OperandInfo() = default;

  /// The data pointer. This may be different from tensor->data_ptr() if the
  /// iterator is split.
  void* data = nullptr;

  /// Stride after broadcasting. The stride is in bytes, not number of
  /// elements.
  StrideVector stride_bytes;

  /// The desired device and type for the operand. For inputs, this specifies
  /// that the input should be converted to this type if necessary. For
  /// outputs, this specifies which type to allocate. target_dtype and device
  /// are initialized with the dtype and device of the tensor, but during type
  /// promotion target_dtype may become different from the tensor's dtype.
  /// Also, during type promotion target_dtype and device can be set for an
  /// undefined tensor so that the tensor can be properly constructed later.
  std::optional<Device> device = c10::nullopt;
  ScalarType target_dtype = ScalarType::Undefined;

  // Caches the dtype of the tensor, because scalar_type() is an expensive
  // operation. If the dtype of the tensor is changed (e.g. as a result of
  // type promotion or in allocate_outputs), this value should be changed too.
  ScalarType current_dtype = ScalarType::Undefined;

  bool is_device_defined() const {
    return device.has_value();
  }
  bool is_type_defined() const {
    return target_dtype != ScalarType::Undefined;
  }
  TensorOptions options() const {
    return TensorOptions(target_dtype).device(device);
  }

  bool is_output = false;

  // will_resize is only for output tensors.
  // 1) Functional call (like torch.add(self, other)): the output tensor is
  //    undefined, and PyTorch creates a new tensor using the common shape
  //    and the stride computed in TensorIterator;
  // 2) In-place call (like torch.add_(self, other)): the output tensor is the
  //    same as an input tensor, and its size and stride can't be modified;
  // 3) Op call with an output (like torch.add(self, other, out=output)): the
  //    output tensor is defined, but its shape may differ from the common
  //    shape. If the shape is not the same as the common shape, this output
  //    tensor will be resized using the common shape and the stride computed
  //    in TensorIterator. Otherwise, its size and stride can't be modified.
  bool will_resize = false;
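
  // For illustration (a hedged sketch; `self`, `other`, and `out` are
  // hypothetical tensors), the three cases above roughly correspond to:
  //
  //   at::add(self, other);           // 1) functional: output allocated by
  //                                   //    the iterator, will_resize applies
  //   self.add_(other);               // 2) in-place: output aliases an input,
  //                                   //    never resized
  //   at::add_out(out, self, other);  // 3) out= variant: resized only if its
  //                                   //    shape differs from the common shape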

  bool is_read_write = false;
  bool is_const = false;

  void validate() {
    TORCH_CHECK(
        !tensor_base_->defined() || tensor_base_->layout() == kStrided,
        "unsupported tensor layout: ",
        tensor_base_->layout());
  }

  /// The tensor operand. Note that the strides, data pointer, and
  /// other attributes may differ due to dimension reordering and
  /// coalescing.
  const Tensor& tensor() const {
    return tensor_storage_.getTensor();
  }
  const TensorBase& tensor_base() const {
    return *tensor_base_;
  }
  void tensor(c10::MaybeOwned<TensorBase>&& tensor);

  // Save the original tensor operand in cases when an output is modified
  // (e.g. if dtype is changed)
  const Tensor& original_tensor() const {
    return original_tensor_storage_.getTensor();
  }
  const TensorBase& original_tensor_base() const {
    return *original_tensor_base_;
  }

  // Set tensor to a new value, and store the old tensor value in
  // original_tensor. Should only ever be called once for the lifetime of an
  // operand.
  void exchange_tensor(c10::MaybeOwned<TensorBase>&& new_tensor);

  // Move original_tensor back into tensor; exchange_tensor must have been
  // called before.
  void restore_original_tensor();

 private:
  c10::MaybeOwned<TensorBase> tensor_base_;
  c10::MaybeOwned<TensorBase> original_tensor_base_ =
      c10::MaybeOwned<TensorBase>::owned(std::in_place);

  // We store TensorBase visibly in the header to allow inline access.
  // However, we sometimes need a genuine `const Tensor &` for the
  // TensorIterator API. So, we also store a non-owning `Tensor`
  // object in these `_storage_` variables.
  internal::OpaqueOptionalTensorRef tensor_storage_;
  internal::OpaqueOptionalTensorRef original_tensor_storage_;
};

struct SplitUntil32Bit;

enum class FastSetupType : uint8_t {
  NONE,
  CONTIGUOUS,
  CHANNELS_LAST,
  NON_OVERLAPPING_DENSE
};

class TensorIteratorConfig;
struct TensorIterator;

struct TORCH_API TensorIteratorBase : public impl::MetaBase {
  using DimMask = std::bitset<64>;
  using PtrVector = SmallVector<char*, 4>;
  using StrideVector = SmallVector<int64_t, 6>;

  TensorIteratorBase();
  void build(TensorIteratorConfig&);

  // The inner-loop function operates on the fastest moving dimension. It
  // implements element-wise operations in terms of 1-d strided tensors.
  //
  // Arguments:
  //  data: data pointers for each operand (length `ntensors`)
  //  strides: stride for each operand (length `ntensors`)
  //  size: size of inner loop
  //
  // The `size` often matches shape[0], but may be smaller due to
  // parallelization of the inner loop.
  using loop2d_t = c10::function_ref<
      void(char** data, const int64_t* strides, int64_t size0, int64_t size1)>;
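
  // As a rough sketch (assuming a float copy-like op on an already built
  // iterator `iter` with one output and one input), a 1-d loop passed to
  // for_each() looks like this; strides are in bytes:
  //
  //   iter.for_each([](char** data, const int64_t* strides, int64_t n) {
  //     char* out = data[0];
  //     char* in = data[1];
  //     for (int64_t i = 0; i < n; i++) {
  //       *reinterpret_cast<float*>(out + i * strides[0]) =
  //           *reinterpret_cast<float*>(in + i * strides[1]);
  //     }
  //   });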

  using loop_subiter_t = c10::function_ref<void(TensorIteratorBase& subiter)>;

  void foreach_reduced_elt(loop_subiter_t loop, bool parallelize = true);

  int ndim() const {
    return static_cast<int>(shape_.size());
  }
  IntArrayRef shape() const {
    return shape_;
  }
  int64_t numel() const;
  int ntensors() const {
    return static_cast<int>(operands_.size());
  }
  int noutputs() const {
    return num_outputs_;
  }
  int ninputs() const {
    return ntensors() - noutputs();
  }
  IntArrayRef view_offsets() const {
    return view_offsets_;
  }

  /// Number of elements in the output operand. This is the same as numel() for
  /// operations that are not reductions.
  int64_t num_output_elements() const;

  /// Number of reduced dimensions in a reduction operation
  int num_reduce_dims() const;

  /// 1-dimensional iteration and no buffering or type conversion
  bool is_trivial_1d() const;
  /// Reducible to 1-dimensional and all operands are contiguous
  bool is_contiguous() const;
  bool is_dim_reduced(int dim) const;

  /// Accessors for each operand
  IntArrayRef strides(int64_t arg) const {
    return operands_[arg].stride_bytes;
  }
  void* data_ptr(int64_t arg) const;
  ScalarType dtype(int64_t arg = 0) const {
    return operands_[arg].current_dtype;
  }
  ScalarType common_dtype() const {
    TORCH_INTERNAL_ASSERT(
        common_dtype_ != ScalarType::Undefined,
        "Queried for invalid common dtype!");
    return common_dtype_;
  }
  ScalarType input_dtype(int64_t arg = 0) const {
    return operands_[num_outputs_ + arg].current_dtype;
  }
  Device device(int64_t arg = 0) const {
    return operands_[arg].device.value();
  }
  c10::DeviceType device_type(int64_t arg = 0) const {
    return device(arg).type();
  }
  int64_t element_size(int64_t arg) const {
    return static_cast<int64_t>(elementSize(dtype(arg)));
  }
  bool is_scalar(int64_t arg) const;
  bool is_cpu_scalar(int64_t arg) const;

  const TensorBase& tensor_base(int64_t arg) const {
    return operands_[arg].tensor_base();
  }
  const Tensor& tensor(int64_t arg) const {
    return operands_[arg].tensor();
  }

  const TensorBase& output_base(int64_t arg = 0) const {
    AT_ASSERT(arg < num_outputs_);
    return tensor_base(arg);
  }
  const Tensor& output(int64_t arg = 0) const {
    AT_ASSERT(arg < num_outputs_);
    return tensor(arg);
  }
  const TensorBase& input_base(int64_t arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor_base(num_outputs_ + arg);
  }
  const Tensor& input(int64_t arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor(num_outputs_ + arg);
  }

  // Copies from temporary outputs back to the original outputs
  // NOTE: only used on CPU
  void cast_outputs();

  /// Removes an operand from this iterator
  void remove_operand(int64_t arg);
  /// Shrinks an iterated dimension
  void narrow(int dim, int64_t start, int64_t size);
  /// Narrows every dim after and including `start_dim` to size one.
  void select_all_keeping_dim(int start_dim, IntArrayRef starts);
  /// Replaces the data pointer for the operand at index `arg`.
  /// The new pointer should have the same sizes, strides and dtype as the
  /// original.
  void unsafe_replace_operand(int64_t arg, void* data);

  /// Splits this TensorIterator into two iterators. Together they iterate over
  /// the entire operation. Used by `with_32bit_indexing()`.
  std::unique_ptr<TensorIterator> split(int dim);

  /// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim]
  int get_dim_to_split() const;

  template <typename T>
  T scalar_value(int64_t arg) {
    auto& op = operands_[arg];
    return c10::fetch_and_cast<T>(op.tensor_base().scalar_type(), op.data);
  }

  /// Return the scalar value from original_tensor_base if it is defined. When
  /// common_dtype is Half, casting a scalar input to common_dtype might
  /// overflow. If the scalar is already given as Half, then return the scalar
  /// value from tensor_base.
  template <typename T>
  T original_scalar_value(int64_t arg) {
    auto& original_tensor_base = operands_[arg].original_tensor_base();
    if (original_tensor_base.defined()) {
      TORCH_INTERNAL_ASSERT(
          original_tensor_base.scalar_type() != common_dtype());
      return c10::fetch_and_cast<T>(
          original_tensor_base.scalar_type(),
          original_tensor_base.const_data_ptr());
    } else {
      return scalar_value<T>(arg);
    }
  }

 private:
  template <typename loop1d_t>
  auto loop_2d_from_1d(const loop1d_t& loop) {
    return
        [loop, ntensor = ntensors()](
            char** base, const int64_t* strides, int64_t size0, int64_t size1) {
          PtrVector data(base, base + ntensor);
          const int64_t* outer_strides = &strides[ntensor];
          for (const auto i : c10::irange(size1)) {
            if (i > 0) {
              for (const auto arg : c10::irange(ntensor)) {
                data[arg] += outer_strides[arg];
              }
            }
            loop(data.data(), strides, size0);
          }
        };
  }

 public:
  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible_v<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>,
          int> = 0>
  void for_each(loop1d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE) {
    for_each(loop_2d_from_1d(loop), grain_size);
  }

  void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE);

  void parallel_reduce(loop2d_t loop);

  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible_v<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>,
          int> = 0>
  void serial_for_each(loop1d_t loop, Range range) {
    serial_for_each(loop_2d_from_1d(loop), range);
  }

  void serial_for_each(loop2d_t loop, Range range) const;
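
  // As an illustrative sketch (not the actual implementation; assumes
  // ATen/Parallel.h and some loop `loop`), for_each is conceptually
  // equivalent to splitting the element range and handing each chunk to
  // serial_for_each:
  //
  //   at::parallel_for(
  //       0, iter.numel(), at::internal::GRAIN_SIZE,
  //       [&](int64_t begin, int64_t end) {
  //         iter.serial_for_each(loop, {begin, end});
  //       });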

  /// Create a strides array for a Tensor with shape of this iterator. The
  /// parameter `element_size` specifies the size of Tensor's data type in
  /// bytes (e.g. `4` for `float`)
  StrideVector compatible_stride(int64_t element_size) const;

  /// Inverts the re-ordering done by reorder_dimensions. This can only be
  /// called *before* coalesce_dimensions() is called.
  DimVector invert_perm(IntArrayRef input) const;

  /// Reapplies the same re-ordering as done by reorder_dimensions. This can
  /// only be called *before* coalesce_dimensions() is called.
  DimVector apply_perm_and_mul(IntArrayRef input, int mul) const;

  /// Helper functions for CPU iteration
  StrideVector get_dim_strides(int dim) const;
  StrideVector get_strides() const;
  StrideVector get_inner_strides() const {
    return get_dim_strides(0);
  }
  PtrVector get_base_ptrs() const;

  // Helper functions for advanced stride manipulations (e.g. torch.flip)
  void _unsafe_set_arg_strides(const int64_t arg, IntArrayRef strides) {
    operands_[arg].stride_bytes = strides;
  }
  void _unsafe_set_arg_data(const int64_t arg, void* data) {
    operands_[arg].data = data;
  }

  // Helper functions for custom devices; a custom device can access
  // OperandInfo and NameVector on its side.
  const OperandInfo& operand(int arg = 0) const {
    return operands_[arg];
  }
  OperandInfo& operand(int arg = 0) {
    return operands_[arg];
  }
  NameVector& get_dim_names() {
    return names_;
  }
  const NameVector& get_dim_names() const {
    return names_;
  }

  /// true if the stride computation can use 32-bit arithmetic. Used by GPU
  /// kernels
  bool can_use_32bit_indexing() const;

  /// An "iterable" object that recursively splits this iterator into
  /// sub-iterators that can use 32-bit indexing.
  SplitUntil32Bit with_32bit_indexing() const;
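
  // A common usage pattern in CUDA kernels (a hedged sketch; `launch_kernel`
  // stands in for whatever actually consumes the iterator) is to recurse over
  // 32-bit-indexable sub-iterators:
  //
  //   void launch_kernel(TensorIteratorBase& iter) {
  //     if (!iter.can_use_32bit_indexing()) {
  //       for (auto& sub_iter : iter.with_32bit_indexing()) {
  //         launch_kernel(sub_iter);
  //       }
  //       return;
  //     }
  //     // ... launch with 32-bit index arithmetic here ...
  //   }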

  /// If the kernel should accumulate into the output. Only relevant for CUDA
  /// reductions.
  bool should_accumulate() const {
    return accumulate_;
  }

  /// Whether this iterator produces the actual output,
  /// as opposed to something that will be accumulated further. Only relevant
  /// for CUDA reductions.
  bool is_final_output() const {
    return final_output_;
  }

  bool has_contiguous_first_dim() const {
    if (ndim() == 0) {
      return true;
    }
    int num_tensors = ntensors();
    for (const auto i : c10::irange(num_tensors)) {
      if (strides(i)[0] != element_size(i)) {
        return false;
      }
    }
    return true;
  }

  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;

#define TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, maybestatic)            \
  maybestatic void methodname(                                              \
      TensorBase&& out, const TensorBase& a, const TensorBase& b) = delete; \
  maybestatic void methodname(                                              \
      const TensorBase& out, TensorBase&& a, const TensorBase& b) = delete; \
  maybestatic void methodname(                                              \
      const TensorBase& out, const TensorBase& a, TensorBase&& b) = delete; \
  maybestatic void methodname(                                              \
      TensorBase&& out, TensorBase&& a, const TensorBase& b) = delete;      \
  maybestatic void methodname(                                              \
      TensorBase&& out, const TensorBase& a, TensorBase&& b) = delete;      \
  maybestatic void methodname(                                              \
      const TensorBase& out, TensorBase&& a, TensorBase&& b) = delete;      \
  maybestatic void methodname(                                              \
      TensorBase&& out, TensorBase&& a, TensorBase&& b) = delete;

#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, )

  void build_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_float_op)
  void build_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op)
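
  // A hedged sketch of how a structured kernel's meta function typically
  // drives one of these builders (mirroring the pattern used for binary ops;
  // the op name and signature belong to the op being defined, not to this
  // header, and op-specific checks are omitted):
  //
  //   TORCH_META_FUNC2(add, Tensor)
  //   (const Tensor& self, const Tensor& other, const Scalar& alpha) {
  //     build_borrowing_binary_op(maybe_get_output(), self, other);
  //   }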

  void build_unary_float_op(const TensorBase& out, const TensorBase& a);
  void build_borrowing_unary_float_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op)
  void build_unary_op(const TensorBase& out, const TensorBase& a);
  // Odd special case needed for pow. Has to borrow the output because
  // it's a structured kernel, but the argument is potentially a copy.
  void build_output_borrowing_argument_owning_unary_op(
      const TensorBase& out,
      const TensorBase& a);
  void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op)
  void build_borrowing_unary_force_boolean_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op)
  void build_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op)
  // Another special case: we need to own the second argument for comparison
  // ops.
  void build_borrowing_except_last_argument_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_ternary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b,
      const TensorBase& c);

#undef TORCH_DISALLOW_TEMPORARIES

 protected:
  // Mutable reference as it moves tensors out of TensorIteratorConfig
  void populate_operands(TensorIteratorConfig&);
  void mark_outputs();
  void mark_resize_outputs(const TensorIteratorConfig&);
  void compute_mem_overlaps(const TensorIteratorConfig&);
  void compute_shape(const TensorIteratorConfig&);
  void compute_strides(const TensorIteratorConfig&);
  void reorder_dimensions();
  void permute_dimensions(IntArrayRef perm);
  void compute_types(const TensorIteratorConfig&);
  ScalarType compute_common_dtype();
  void allocate_or_resize_outputs();
  bool fast_set_up(const TensorIteratorConfig&);
  FastSetupType compute_fast_setup_type(const TensorIteratorConfig&);
  void compute_names(const TensorIteratorConfig&);
  void propagate_names_to_outputs();
  void coalesce_dimensions();

 protected:
  /// Records the "computation" shape of the output tensor. The computation
  /// shape is different from the regular shape in a few ways:
  ///
  ///   - The shape may be permuted (via permute_dimensions) so that we
  ///     process the dimensions in the most computationally efficient order
  ///     (rather than the logical order given to us by the users.)
  ///   - The shape may have adjacent dimensions collapsed (via
  ///     coalesce_dimensions) so that we minimize the number of
  ///     dimensions we have to explicitly iterate over. For example,
  ///     a pointwise operation on a contiguous tensor "computationally"
  ///     consists of only a single dimension.
  ///
  /// In other words, the computation shape is the output shape as it
  /// actually matters for implementing the kernel, but not necessarily the
  /// output shape that the user will see in the end.
  ///
  /// The lifecycle of mutations to shape_ in TensorIterator:
  ///   - declare_static_shape() sets an initial shape explicitly
  ///     provided by user, otherwise
  ///   - compute_shape() computes the true (non-computational) shape
  ///     specified by the user.
  ///   - reorder_dimensions() reorders dimensions to improve coalescing.
  ///   - coalesce_dimensions() then coalesces adjacent dimensions when
  ///     possible.
  ///
  /// The shape may also be further modified if we create sub-TensorIterators,
  /// e.g., via narrow or select_all_keeping_dim.
  DimVector shape_;

  /// Temporarily records the permutation computed by reorder_dimensions.
  /// This permutation maps the computation output dimension (dim) to
  /// the original true output dimension (perm_[dim]). It is used by
  /// invert_perm to undo the permutation. After coalesce_dimensions is
  /// called, the permutation is no longer valid (as, in general, there
  /// is no permutation that will map computation dimensions to
  /// output dimensions); methods that manipulate perm_ are obligated
  /// to test that !has_coalesced_dimensions
  DimVector perm_;

  /// Has coalesce_dimensions() (or any moral equivalent, e.g., fast_build())
  /// been called? This is SOLELY used to check validity of perm_.
  bool has_coalesced_dimensions_ = false;

  /// Whether iteration must be fixed. This disables dimension permuting and
  /// also changes how for_each divides work among threads.
  bool enforce_linear_iteration_ = false;

  /// The index offsets into the original tensors for each dimension.
  /// This is only non-zero when you narrow() a TensorIterator (e.g.,
  /// when you make sub-TensorIterators).
  DimVector view_offsets_;

  /// The computed names of the output tensor. Computed by compute_names()
  NameVector names_;

  /// The operands of the TensorIterator: both the inputs and outputs. The
  /// outputs MUST come first in the operands_ list. There is always an
  /// operand for each output of the TensorIterator, even if TensorIterator
  /// will ultimately be responsible for allocating the output; in those
  /// cases, tensor is simply undefined (and will be populated later
  /// during build()).
  ///
  /// This list is initially populated prior to build(), but build() mutates
  /// OperandInfo to populate more information.
  SmallVector<OperandInfo, 4> operands_;

  /// Number of outputs in operands_ (the length of the outputs prefix
  /// in operands_).
  int num_outputs_ = 0;

  /// Whether or not all operands have the same shape and are 1d+. Having all
  /// the same shape affects whether or not the iterator is eligible for fast
  /// setup.
  bool all_ops_same_shape_ = false;

  /// Whether or not all operands are 0d; this affects type promotion
  bool all_ops_are_scalars_ = false;

  /// The "computation" dtype of TensorIterator, specifying the dtype in which
  /// TensorIterator performs its internal computation. Typically,
  /// this matches the dtype of the output tensors, but not always!
  ScalarType common_dtype_ = ScalarType::Undefined;

  /// This is currently defined as kCPU, or the device of the first non-CPU
  /// tensor argument. See TensorIteratorBase::compute_types for details.
  Device common_device_ = kCPU;

  /// Set by split(), see should_accumulate() and is_final_output()
  bool accumulate_ = false;
  bool final_output_ = true;

  // From TensorIteratorConfig
  bool is_reduction_ = false;

  /// Set by populate_operands(), says if we're handling meta tensors
  bool is_meta_ = false;
};

struct TORCH_API TensorIterator final : public TensorIteratorBase {
  TensorIterator() : TensorIteratorBase() {}
  // Slicing is OK, TensorIterator guaranteed NOT to have any fields
  TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {}

#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, static)

  static TensorIterator binary_float_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator binary_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(borrowing_binary_op)
  static TensorIterator comparison_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator unary_op(TensorBase& out, const TensorBase& a);
  static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a);
  static TensorIterator nullary_op(TensorBase& out);
  static TensorIterator borrowing_nullary_op(const TensorBase& out);
  static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete;
  static TensorIterator reduce_op(TensorBase& out, const TensorBase& a);
  static TensorIterator reduce_op(
      TensorBase& out1,
      TensorBase& out2,
      const TensorBase& a);
#undef TORCH_DISALLOW_TEMPORARIES
#undef TORCH_DISALLOW_TEMPORARIES_IMPL

  const Tensor& maybe_get_output(int64_t output_idx) override;
  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;
};
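
// A hedged usage sketch for the factory helpers above (`out`, `a`, and `b`
// are illustrative tensors; cpu_kernel lives in ATen/native/cpu/Loops.h and
// is only available inside kernel translation units):
//
//   at::Tensor out = at::empty_like(a);
//   auto iter = at::TensorIterator::binary_op(out, a, b);
//   at::native::cpu_kernel(iter, [](float x, float y) { return x + y; });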

class TORCH_API TensorIteratorConfig final {
 public:
  friend struct TensorIteratorBase;
  friend struct TensorIterator;

  TensorIteratorConfig() = default;

  C10_DISABLE_COPY_AND_ASSIGN(TensorIteratorConfig);

  /// Construction
  // Stores input/output Tensors without incrementing the reference count.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_output(const TensorBase& output) {
    return add_borrowed_output(output);
  }
  TensorIteratorConfig& add_input(const TensorBase& input) {
    return add_borrowed_input(input);
  }
  TensorIteratorConfig& add_const_input(const TensorBase& input) {
    return add_borrowed_const_input(input);
  }

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_input(TensorBase&& input) = delete;
  TensorIteratorConfig& add_const_input(TensorBase&& input) = delete;

  // Stores input/output Tensors while incrementing the reference count.
  // Note that add_{in,out}put are nearly always what you
  // want, and the exception (adding an unnamed temporary) won't
  // compile.
  TensorIteratorConfig& add_owned_output(const TensorBase& output);
  TensorIteratorConfig& add_owned_input(const TensorBase& input);
  TensorIteratorConfig& add_owned_const_input(const TensorBase& input);

  // Advanced API: stores input/output Tensors without incrementing
  // the reference count. The caller must ensure that these Tensors
  // live at least as long as this TensorIteratorConfig and any
  // TensorIteratorBase built from this TensorIteratorConfig.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_borrowed_output(const TensorBase& output);
  TensorIteratorConfig& add_borrowed_input(const TensorBase& input);
  TensorIteratorConfig& add_borrowed_const_input(const TensorBase& input);

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_borrowed_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_borrowed_input(TensorBase&& input) = delete;
  TensorIteratorConfig& add_borrowed_const_input(TensorBase&& input) = delete;

  // Sets the check_mem_overlap_ flag, which is true by default.
  // If true, inputs are checked for partial overlap with the outputs and
  // outputs are checked for internal overlap (e.g. broadcasted views). An
  // error is raised if unacceptable overlap is detected.
  // If you're migrating an existing operator to using TensorIterator, please
  // consider if the previous implementation checked memory overlap. If it did
  // not, and if the operator is idempotent (for example, Tensor.fill_(0)),
  // then checking memory overlap is BC-breaking. Please don't check memory
  // overlap in that case.
  TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap) {
    check_mem_overlap_ = check_mem_overlap;
    return *this;
  }

  // Sets the check_all_same_dtype_ flag, which is true by default.
  // If true, checks that all inputs and defined outputs have the same dtype.
  // Setting either promote_inputs_to_common_dtype_
  // or cast_common_dtype_to_outputs_ to true will set
  // check_all_same_dtype_ to false.
  TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype) {
    check_all_same_dtype_ = _check_all_same_dtype;
    return *this;
  }

  // Sets the check_all_same_device_ flag, which is true by default.
  // If true, all operands must be on the same device, with the possible
  // exception of CPU scalars, which can be passed to some CUDA kernels
  // as kernel arguments.
  TensorIteratorConfig& check_all_same_device(
      const bool _check_all_same_device) {
    check_all_same_device_ = _check_all_same_device;
    return *this;
  }

  // Sets the enforce_safe_casting_to_output_ flag, which is false by default.
  // If true, the iterator's "common dtype" must be computable
  // (see the [Common Dtype Computation] note) and
  // canCast(common dtype, output dtype) must be true for all outputs.
  TensorIteratorConfig& enforce_safe_casting_to_output(
      const bool _enforce_safe_casting_to_output) {
    enforce_safe_casting_to_output_ = _enforce_safe_casting_to_output;
    return *this;
  }

  // Sets the enforce_linear_iteration_ flag, which is false by default.
  // If true, iteration goes in the same order as a C-contiguous tensor
  // is laid out in memory, i.e. the last dimension iterates fastest.
  //
  // This iteration order can be less efficient and may even prevent
  // vectorization. So only use it if the correctness of your kernel depends
  // on it.
  TensorIteratorConfig& enforce_linear_iteration(
      const bool _enforce_linear_iteration = true) {
    enforce_linear_iteration_ = _enforce_linear_iteration;
    return *this;
  }

  // Sets the promote_inputs_to_common_dtype_ flag, which is false by default.
  // If true, the iterator's "common dtype" is always computed (see the
  // [Common Dtype Computation] note) and, on the CPU, temporary copies of
  // the inputs in the common dtype are passed as the actual inputs to
  // the operation.
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& promote_inputs_to_common_dtype(
      const bool _promote_inputs_to_common_dtype) {
    promote_inputs_to_common_dtype_ = _promote_inputs_to_common_dtype;
    if (_promote_inputs_to_common_dtype) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }

  // Sets the promote_integer_inputs_to_float_ flag, which is false by default.
  // NOTE: If set to true, then promote_inputs_to_common_dtype_ must also be
  // true. If true and the iterator's "common dtype" is an integral type
  // (including bool), then the common dtype is changed to the default float
  // scalar type.
  TensorIteratorConfig& promote_integer_inputs_to_float(
      const bool _promote_integer_inputs_to_float) {
    promote_integer_inputs_to_float_ = _promote_integer_inputs_to_float;
    TORCH_INTERNAL_ASSERT(
        !promote_integer_inputs_to_float_ || promote_inputs_to_common_dtype_);
    return *this;
  }

  TensorIteratorConfig& is_reduction(const bool _is_reduction) {
    is_reduction_ = _is_reduction;
    return *this;
  }

  TensorIteratorConfig& allow_cpu_scalars(const bool _allow_cpu_scalars) {
    allow_cpu_scalars_ = _allow_cpu_scalars;
    return *this;
  }

  // Sets the cast_common_dtype_to_outputs_ flag, which is false by default.
  // If true, the iterator's "common dtype" must be computable
  // (see the [Common Dtype Computation] note) and, on the CPU, temporary
  // copies of the outputs are passed as the actual output to the operation.
  // These temporaries are then copied to the original outputs after
  // the operation is performed (see cast_outputs()).
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& cast_common_dtype_to_outputs(
      const bool _cast_common_dtype_to_outputs) {
    cast_common_dtype_to_outputs_ = _cast_common_dtype_to_outputs;
    if (_cast_common_dtype_to_outputs) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }

  TensorIteratorConfig& resize_outputs(bool resize_outputs) {
    resize_outputs_ = resize_outputs;
    return *this;
  }

  // Bypass output dtype/device computation and fix the dtype/device as
  // specified here.
  TensorIteratorConfig& declare_static_dtype_and_device(
      ScalarType dtype,
      Device device);
  TensorIteratorConfig& declare_static_dtype(ScalarType dtype);
  TensorIteratorConfig& declare_static_device(Device device);
  TensorIteratorConfig& declare_static_shape(IntArrayRef shape);
  TensorIteratorConfig& declare_static_shape(
      IntArrayRef shape,
      IntArrayRef squash_dims);

  // It would be better if this was && qualified, but this would be at the cost
  // of a lot of boilerplate above
  TensorIterator build() {
    TensorIterator iter;
    iter.build(*this);
    return iter;
  }

 private:
  bool is_tensor_const(size_t idx);

  SmallVector<c10::MaybeOwned<TensorBase>, 4> tensors_;
  int num_outputs_ = 0;
  int num_inputs_ = 0;

  std::optional<DimVector> static_shape_ = c10::nullopt;
  std::optional<ScalarType> static_dtype_ = c10::nullopt;
  std::optional<Device> static_device_ = c10::nullopt;
  bool check_mem_overlap_ = true;
  bool allow_cpu_scalars_ = false;
  bool is_reduction_ = false;
  bool resize_outputs_ = true;
  bool check_all_same_dtype_ = true;
  bool check_all_same_device_ = true;
  bool enforce_safe_casting_to_output_ = false;
  bool enforce_linear_iteration_ = false;
  bool promote_inputs_to_common_dtype_ = false;
  bool promote_integer_inputs_to_float_ = false;
  bool cast_common_dtype_to_outputs_ = false;

  SmallVector<size_t, 4> const_tensor_indices_;
};
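
// A hedged configuration sketch (the flag choices here are purely
// illustrative; `result` and `self` are tensors defined by the caller)
// showing how the fluent setters above compose before build():
//
//   auto iter = at::TensorIteratorConfig()
//       .set_check_mem_overlap(false)
//       .add_owned_output(result)
//       .add_owned_const_input(self)
//       .is_reduction(true)
//       .promote_inputs_to_common_dtype(true)
//       .build();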

/// A container-like struct that acts as if it contains splits of a
/// TensorIterator that can use 32-bit indexing. Taken together the splits
/// cover the original TensorIterator.
struct TORCH_API SplitUntil32Bit {
  struct TORCH_API iterator {
    iterator() = default;
    iterator(const TensorIteratorBase& iter);
    iterator(iterator&&) = default;

    // Guaranteed to be a TensorIterator proper!
    TensorIterator& operator*() const;
    iterator& operator++();
    bool operator==(const iterator& other) const {
      // two iterators are equal if they are the same object or they're both
      // empty
      return this == &other || (vec.empty() && other.vec.empty());
    }
    // needed for C++11 range-based for loop
    bool operator!=(const iterator& other) const {
      return !(*this == other);
    }
    /// stack of TensorIterators to be split
    std::vector<std::unique_ptr<TensorIterator>> vec;
  };

  SplitUntil32Bit(const TensorIteratorBase& iter) : iter(iter) {}

  iterator begin() const;
  iterator end() const;

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const TensorIteratorBase& iter;
};

} // namespace at