LegacyBatchedTensorImpl.h

#pragma once

#include <bitset>

#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>

namespace at {

// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;

// The valid vmap levels are in the range [0, 64). This effectively means that
// we support a maximum of 64 nested vmaps.
constexpr int64_t kVmapNumLevels = 64;

// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;

// A BatchDim represents a "private" dimension on a Tensor created inside of
// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
// is being vmapped over and the `level` being an identifier for which vmap
// said dimension was created inside. The `dim` corresponds to a "physical
// dim" - it is a dimension index on the underlying physical tensor that is
// being vmapped over.
struct BatchDim {
  BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
  int64_t dim() const {
    return dim_;
  }
  int64_t level() const {
    return level_;
  }

 private:
  int64_t dim_;
  int64_t level_;
};

using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
using BatchDimsRef = ArrayRef<BatchDim>;
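
// For example, a tensor vmapped over at levels 1 and 2 might carry the
// following BatchDims (hypothetical values for illustration; note they are
// kept in increasing `level` order, see
// NOTE: [BatchedTensorImpl levels invariant] below):
//   BatchDims bdims = {BatchDim(/*level=*/1, /*dim=*/0),
//                      BatchDim(/*level=*/2, /*dim=*/1)};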

// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDims.
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimensions are treated as being "private"; they are not
// user-visible. For example, in the following Tensor,
//    bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
// dimensions 0 and 1 are batch dimensions.
//
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 2 in the underlying ones(2, 3, 5, 7)
// tensor.
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  explicit BatchedTensorImpl(Tensor value, BatchDims bdims);

  // Returns a reference to BatchDims that represent which dimensions of this
  // tensor are private.
  BatchDimsRef bdims() const {
    return bdims_;
  }

  // BatchedTensorImpl wraps a Tensor
  const Tensor& value() const {
    return value_;
  }

  // Given a public dimension index, return the dimension index in the
  // underlying value() tensor. For example, if we have
  //    bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=2)])
  // bt.actualDim(0) -> 1
  // bt.actualDim(1) -> 3
  // bt.actualDim(2) -> Error
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
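  // (A sketch of the mapping, not necessarily how the .cpp implements it:
  // walk the physical dims of value() in increasing order, skip every dim
  // marked in createBatchDimBitset(bdims()), and return the (dim+1)-th
  // remaining physical dim.)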

  // We have to override this because we opted into CustomStrides
  IntArrayRef strides_custom() const override;

  // Override a bunch of methods inherited from TensorImpl to return error
  // messages.
  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
  bool has_storage() const override;
#endif

 private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  Tensor value_;

  // NOTE: [BatchedTensorImpl levels invariant]
  // There is an invariant that the BatchDims must be stored in increasing
  // `level` order. That is, for i < j, bdims_[i].level must be less than
  // bdims_[j].level.
  BatchDims bdims_;
};
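
// To make the size/dim bookkeeping above concrete: with
//   bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
// the physical sizes are (2, 3, 5, 7), bt.sizes() is (5, 7), and actualDim
// maps public dim 0 -> physical dim 2 and public dim 1 -> physical dim 3.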

// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
}

// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
  if (!isBatchedTensor(tensor)) {
    return nullptr;
  }
  return unsafeGetBatchedImpl(tensor);
}
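
// A typical caller-side pattern (a minimal sketch; `tensor` stands for any
// at::Tensor the caller holds):
//   if (auto* batched = maybeGetBatchedImpl(tensor)) {
//     const Tensor& physical = batched->value();
//     for (const auto& bdim : batched->bdims()) {
//       // bdim.level() and bdim.dim() identify each private batch dim
//     }
//   }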

// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
    BatchDimsRef bdims) {
  std::bitset<kVmapMaxTensorDims> is_bdim;
  for (const auto& bdim : bdims) {
    is_bdim.set(bdim.dim());
  }
  return is_bdim;
}
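
// For example, for bdims = [(lvl=1, dim=0), (lvl=2, dim=2)] the returned
// bitset has bits 0 and 2 set (0b...0101), marking physical dims 0 and 2
// as batch dims.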

// Creates a bitset for all of the levels present in `bdims`
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
  std::bitset<kVmapNumLevels> result;
  for (const auto& bdim : bdims) {
    result.set(bdim.level());
  }
  return result;
}
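
// For the same example bdims = [(lvl=1, dim=0), (lvl=2, dim=2)], this bitset
// has bits 1 and 2 set instead: it records vmap levels rather than dims.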

inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
  out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
  return out;
}

// Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);

// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
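
// A sketch of how nested vmaps might stack batch dims (hypothetical values,
// following the public/private dim rules documented above):
//   Tensor t = at::ones({2, 3, 5});
//   Tensor bt1 = addBatchDim(t, /*level=*/1, /*dim=*/0);    // sizes: (3, 5)
//   Tensor bt2 = addBatchDim(bt1, /*level=*/2, /*dim=*/0);  // sizes: (5)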

// Checks if an inplace operation on self and other is "vmap compatible".
// See NOTE: [vmap-incompatible in-place operations] for the definition of
// this.
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
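
// (A minimal sketch of what such a check can look like, assuming the
// definition in that note boils down to a level-subset requirement: the
// in-place write is compatible when every vmap level present in `other`
// also appears in `self`, e.g.
//   self_levels == (self_levels | other_levels)
// where both bitsets come from createVmapLevelsBitset.)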

} // namespace at