// FunctionalStorageImpl.h
#pragma once

#include <ATen/Tensor.h>

#include <cstdint>
#include <functional>
#include <utility>
#include <vector>
  3. namespace at::functionalization {
// See Note [Functionalization Pass In Core]
// ViewMeta is a class used by the functionalization pass to navigate between
// a base tensor and a view tensor.
// For example, if I call `b = a.view1(...)`
// the functionalization pass will generate and store a ViewMeta on b that
// looks like:
//
// ViewMeta(
//   [<captures>](const Tensor& base, int64_t mutated_view_idx) {
//     return base.view1(...);
//   },
//   [<captures>](const at::Tensor& base, const at::Tensor& mutated_view,
//                int64_t mutated_view_idx) -> at::Tensor {
//     return at::functionalization::impl::view1_inverse(base, mutated_view,
//                                                       ...);
//   }
// )
//
// The forward_fn lambda describes how to replay view1 on a tensor.
//
// The reverse_fn lambda describes how, given a tensor that is already a view,
// how to get the corresponding base tensor. See Note [Functionalization Pass:
// View Inverses] for details.
struct ViewMeta {
  ViewMeta(
      std::function<Tensor(const Tensor&, int64_t)> forward,
      std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
      bool has_symbolic_inputs,
      bool is_multi_output = false,
      bool is_as_strided = false,
      int64_t out_idx = 0)
      : forward_fn(std::move(forward)),
        reverse_fn(std::move(reverse)),
        out_index(out_idx),
        is_multi_output(is_multi_output),
        is_as_strided(is_as_strided),
        has_symbolic_inputs(has_symbolic_inputs) {}

  // Replays the view op: (base, mutated_view_idx) -> view tensor.
  std::function<Tensor(const Tensor&, int64_t)> forward_fn;
  // Inverts the view op: (base, mutated_view, mutated_view_idx) -> the
  // base tensor with the mutation reflected back onto it.
  std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
  // See Note [out_idx in ViewMeta]
  int64_t out_index;
  // Tells us if this is a multi-output view
  bool is_multi_output;
  // True when the view op was as_strided (which needs special handling,
  // since it cannot be inverted exactly in general).
  bool is_as_strided;
  // Tells us if this view operation has any symbolic inputs
  bool has_symbolic_inputs;

  // Returns a copy of the current ViewMeta, if out_idx matches the current
  // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
  // functions, but a new out index.
  ViewMeta to_out_idx(int64_t out_idx);
};
  54. // FunctionalStorageImpl is a subclass of StorageImpl used by the
  55. // functionalization pass. It has no underlying data (similar to meta storage).
  56. // It also knows how to reflect mutations to tensors in the absence of a valid
  57. // data pointer.
  58. //
  59. // A storage represents the state shared by (potentially multiple) views of the
  60. // same tensor. For example, in the following code:
  61. //
  62. // b = a.view1(...)
  63. // c = b.view2(...)
  64. // b.add_(1)
  65. // --> storage.add_update(b, {view1_meta})
  66. //
  67. // The call to add_(1) will result in a call to alias.add_update(b,
  68. // {view1_meta}), queueing up the mutation from b onto the alias. Later, suppose
  69. // c is used in an expression (e.g. you try to print c, or pass it to an
  70. // operator). Doing so will involve "syncing" c. First we apply any pending
  71. // updates to the alias, and then we regenerate c by replaying its views off of
  72. // the updated alias. E.g:
  73. //
  74. // print(str(c))
  75. // --> c.sync_()
  76. // --> alias.apply_updates() // after this, the alias will be updated to
  77. // reflect the mutation to b
  78. struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
  79. public:
  80. struct Update {
  81. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  82. const at::Tensor new_val;
  83. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  84. const std::vector<ViewMeta> view_metas;
  85. };
  86. explicit FunctionalStorageImpl(const Tensor& value);
  87. void add_update(
  88. const Tensor& updated_val,
  89. const std::vector<ViewMeta>& view_metas);
  90. bool apply_updates();
  91. const Tensor& base() {
  92. return base_;
  93. }
  94. size_t generation() const {
  95. return generation_;
  96. }
  97. void freeze() {
  98. frozen_ = true;
  99. }
  100. c10::SymInt get_storage_size(bool before) {
  101. if (before) {
  102. return original_storage_size_;
  103. } else {
  104. return curr_storage_size_;
  105. }
  106. }
  107. ~FunctionalStorageImpl() override = default;
  108. void mark_mutation() {
  109. mutation_counter_++;
  110. }
  111. void mark_mutation_during_no_grad_or_inference_mode() {
  112. mutation_counter_during_no_grad_or_inference_mode_++;
  113. }
  114. void mark_mutation_hidden_from_autograd() {
  115. mutation_counter_hidden_from_autograd_++;
  116. }
  117. bool are_all_mutations_under_no_grad_or_inference_mode() const {
  118. auto non_autograd_mutations =
  119. mutation_counter_during_no_grad_or_inference_mode_ +
  120. mutation_counter_hidden_from_autograd_;
  121. // The <= is because both counters will technically be incremented, if we
  122. // perform e.g. a triton kernel mutation under no_grad
  123. return mutation_counter_ <= non_autograd_mutations;
  124. }
  125. bool are_all_mutations_hidden_from_autograd() const {
  126. // mutations under no_grad / inference_mode are technically not hidden from
  127. // autograd - they change the version counter
  128. return mutation_counter_ <= mutation_counter_hidden_from_autograd_;
  129. }
  130. void mark_inductor_storage_resize(c10::SymInt new_size) {
  131. inductor_storage_resized_ = true;
  132. curr_storage_size_ = new_size;
  133. }
  134. bool was_inductor_storage_resized() {
  135. return inductor_storage_resized_;
  136. }
  137. private:
  138. // NB: base_ should always point to a tensor BELOW the current
  139. // functionalization layer. This is mainly to avoid reference cycles. e.g.
  140. // given `b = a.view(...)` Both a.storage_ and b.storage_ are a
  141. // FunctionStorageImpl containing an Walualias, with contains a Tensor
  142. // `base_`. In this case (where a and b are FunctionalTensorWrapper's), base_
  143. // should point not to a, but to a's unwrapped value, a.value_` See Note
  144. // [Functionalization: Walualias Removal] for a diagram that shows this
  145. // visually.
  146. at::Tensor base_;
  147. std::vector<Update> updates_;
  148. // generation_ gets incremented every time a mutation is queued onto the
  149. // alias. It is used to determine if a given tensor is "up to date", or if it
  150. // needs to be regenerated from the alias.
  151. size_t generation_ = 0;
  152. // If frozen, no more mutations are allowed on this storage. Once frozen, a
  153. // storage cannot be unfrozen.
  154. bool frozen_ = false;
  155. // These mutation counters are bumped on the storage
  156. // whenever a FunctionalTensorWrapper experiences a mutation.
  157. // When the mutation is under no_grad, or comes from a triton kernel, we also
  158. // bump the corresponding during_no_grad or hidden_from_autograd counters. Why
  159. // do we need to detect these two situations separately from "normal" input
  160. // mutations? (1) "normal" input mutations can mutate autograd metadata like
  161. // .grad_fn,
  162. // in which case they need to be replayed outside of the compiled graph
  163. // (2) "no_grad" input mutations are generally safe to keep in the graph (and
  164. // compile),
  165. // but they bump the tensor's VC, so we need to mark_dirty() on the inputs
  166. // in torch.compile
  167. // (3) mutations that are fully hidden from autograd (e.g. from a triton
  168. // kernel)
  169. // do not mutate any autograd state, and be fully kept in the graph
  170. // When we detect that an input was mutated, we need to be able to tell if:
  171. // (1) all of the mutations were from triton kernels
  172. // (2) all of the mutations were under no_grad
  173. uint64_t mutation_counter_during_no_grad_or_inference_mode_ = 0;
  174. uint64_t mutation_counter_ = 0;
  175. uint64_t mutation_counter_hidden_from_autograd_ = 0;
  176. // Used to tell if:
  177. // (1) There were any storage resizes on a graph input
  178. // (2) The original/curr storage size tell us if these resizes result in a nop
  179. bool inductor_storage_resized_ = false;
  180. c10::SymInt original_storage_size_;
  181. c10::SymInt curr_storage_size_;
  182. };
  183. } // namespace at::functionalization