FunctionalTensorWrapper.h

#pragma once

#include <ATen/ArrayRef.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>
#include <ATen/core/boxing/BoxedKernel.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <c10/core/DispatchKey.h>

namespace at {
// Note [Functionalization Pass In Core]
// The Functionalization pass is used to remove aliasing from a pytorch
// program.
//
// This is useful for backends that don't support aliasing, like XLA and
// Vulkan. It's also necessary in order to remove mutation from a program,
// which is needed in functorch.
//
// Consider this program:
// a = torch.ones(...)
// b = a.view(...)
// b.add_(1)
//
// In this program, b is meant to alias with a due to the use of view(). At the
// end of the program, both a and b are full of 2's. However, backends that
// don't support aliasing aren't able to correctly implement the view()
// operator. Instead, they can opt into the Functionalization pass, which will
// sit between the user and the backend, and provide the necessary aliasing
// logic.
//
// The functionalization pass will turn the above program into a slightly
// different program that has the same semantics, transparently to the user,
// that backends like XLA/Vulkan are able to implement:
//
// a = torch.ones(...)
// b = a.view_copy(...)  # view() replaced with view_copy(). Backends like
//                       # XLA/Vulkan can implement this!
// b.add_(1)
// a.add_(1)  # Our functionalization pass machinery knows that a and b are
//            # aliased - it applies b's mutation to a too.
//
// So, how does the functionalization pass keep track of which tensors are
// aliased? The pass works by wrapping EVERY tensor in the program inside of a
// FunctionalTensorWrapper, which knows about its alias'd tensors.
//
// See Note [Functionalization: Alias Removal] for details on the aliasing
// machinery. See Note [Functionalization: Mutation Removal] for details on
// mutation removal.
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
  explicit FunctionalTensorWrapper(const Tensor& value);
  // Additional constructor to create a FunctionalTensorWrapper directly from
  // an underlying tensor that was created from a view. For example, the code
  // b = a.view1() will generate a constructor call to
  // FunctionalTensorWrapper(b, a, view1_meta).
  explicit FunctionalTensorWrapper(
      const Tensor& view_value,
      const FunctionalTensorWrapper* base,
      const functionalization::ViewMeta& meta);
  // Get the underlying, actual tensor, that doesn't know anything about
  // functionalization.
  const Tensor& value() const {
    return value_;
  };

  // The concept of "level" is only ever important to functorch; it's exposed
  // here as more of a hook for functorch to use.
  int64_t level() const {
    return level_;
  };
  void set_level(int64_t level) {
    level_ = level;
  }

  bool has_metadata_mutation() const {
    return has_metadata_mutation_;
  };

  void mark_mutation() {
    functional_storage_impl()->mark_mutation();
  }
  // Denotes a mutation that's hidden from autograd,
  // e.g. for the purposes of passing a tensor to a triton kernel
  void mark_mutation_hidden_from_autograd() {
    functional_storage_impl()->mark_mutation_hidden_from_autograd();
  }
  void mark_mutation_during_no_grad_or_inference_mode() {
    functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
  }
  // Are all the mutations happening to the tensor hidden from autograd
  bool are_all_mutations_hidden_from_autograd() const {
    return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
  }
  // Did all mutations happen under no_grad or inference_mode
  // (We also need to ignore mutations fully hidden from autograd here)
  bool are_all_mutations_under_no_grad_or_inference_mode() const {
    return functional_storage_impl()
        ->are_all_mutations_under_no_grad_or_inference_mode();
  }

  void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
    is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
  }
  bool is_symbolic() const {
    return is_symbolic_;
  }

  // Runs the forward_fn of every ViewMeta collected in the current instance
  // to some other base.
  Tensor apply_view_metas(const Tensor& base);
  // Sync's the underlying tensor with its alias, if it's out of date. This
  // involves two steps:
  // 1) Apply any pending updates/mutations to the alias
  // 2) Replay the views (if any) to regenerate the current tensor off of the
  //    updated alias.
  void sync_();
  // Performs step (2) of the sync: it replays the view_metas_ off of the
  // updated alias. This is its own public API because it's needed by
  // view_inplace ops like transpose_.
  // See Note [Functionalization Pass - Inplace View Ops]
  void regenerate_from_base();
  // Performs step (1) of the sync. This is its own public API because it's
  // needed by functorch. functorch wants to make sure that all input tensors
  // to a functionalized program have been properly synced so it can properly
  // propagate mutations to inputs. It can't just call sync_(), because the
  // FunctionalTensorWrapper will look like it has no aliases and sync_ will be
  // a noop. We use the reference count on storage_ to determine if the wrapper
  // is aliased, and by the time functorch is ready to propagate updates to
  // inputs, any intermediate views of the input created by the program will
  // have been deallocated. This function also returns whether or not the base
  // actually had any updates to apply.
  bool apply_updates();
  // Takes the current state of value_ and snapshots it, sending it as a
  // pending update to the alias.
  void commit_update();
  // When any tensor is mutated, the tensor increments its alias's
  // "generation". Separately, each tensor maintains its own "generation"
  // counter, which is used to determine if it's up-to-date with its alias. The
  // act of syncing a tensor will set a tensor's generation equal to its
  // alias's generation.
  bool is_up_to_date() const;
  // Freezes the storage of this tensor, preventing subsequent mutations
  void freeze_storage() const;
  // Every FunctionalTensorWrapper contains a vector of ViewMeta objects
  // describing the series of view ops that ran to generate the current tensor
  // from the base tensor. This method is used by inplace-view ops like
  // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
  // tensor by replaying the views off of the alias.
  void mutate_view_meta(const at::functionalization::ViewMeta& meta);

  // Custom implementation of self.set_(src)
  void set__impl(const FunctionalTensorWrapper* other);

  // Custom implementation of resize_storage_bytes_(self, new_size)
  void storage_resize_(c10::SymInt new_size);

  // Returns whether the current tensor's data was ever mutated
  bool has_data_mutation();

  // Returns whether the current FunctionalTensorWrapper
  // experienced a set_() call.
  bool was_storage_changed() {
    return was_storage_changed_;
  }

  c10::SymInt get_storage_size(bool before) {
    return functional_storage_impl()->get_storage_size(before);
  }

  // Returns whether the FunctionalTensor experienced an
  // untyped_storage().resize_() call
  bool was_inductor_storage_resized() {
    return functional_storage_impl()->was_inductor_storage_resized();
  }
  // The functionalization pass can be used to remove mutations.
  // It does so by replacing any mutation op with its corresponding
  // out-of-place op, followed by a call to replace_(). e.g.:
  //
  //   a.add_(1)
  //
  // will turn into:
  //
  //   tmp = a.add(1)
  //   a.replace_(tmp)
  //
  // replace_() swaps out the wrapped tensor, value_, with tmp.
  void replace_(const Tensor& other, bool from_lazy_regenerate = false);

  bool is_multi_output_view() {
    return is_multi_output_view_;
  }
  // See Note [resize_() in functionalization pass]
  void maybe_replace_storage(const Tensor& other);

  // Replaces the storage with a new functional storage,
  // and clears the view_metas_ stack.
  // WARNING: Calling this function will sever the aliasing relationship
  // between the current FunctionalTensorWrapper and any of its outstanding
  // aliases. Please only call if you know what you're doing.
  void _unsafe_reset_storage();

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  ~FunctionalTensorWrapper() override = default;
  // FunctionalTensorWrapper overrides all custom size/stride functions,
  // so that if the inner tensor has a custom implementation
  // we make sure to call that implementation.
  at::IntArrayRef sizes_custom() const override;
  at::IntArrayRef strides_custom() const override;
  int64_t dim_custom() const override;
  int64_t numel_custom() const override;
  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  c10::SymInt sym_size_custom(int64_t d) const override;
  c10::SymIntArrayRef sym_strides_custom() const override;
  c10::SymInt sym_storage_offset_custom() const override;
  c10::Device device_custom() const override;
 private:
  const char* tensorimpl_type_name() const override;
  void set_constructor_metadata();
  functionalization::FunctionalStorageImpl* functional_storage_impl() const;

  // This is used to re-implement shallow_copy_and_detach for
  // FunctionalTensorWrapper. The implementation is identical, but we just need
  // to return a subclass instead of a plain TensorImpl.
  // TODO: maybe it's possible to arrange for that to happen automatically
  // without an override here?
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
  void copy_tensor_metadata_and_refresh(
      const FunctionalTensorWrapper* src_impl,
      FunctionalTensorWrapper* dest_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const;

  // Note that value is not taken by reference: internally, the wrapper will
  // change the value tensor that it points to over time.
  Tensor value_;
  int64_t level_{};
  // These two counters are used for identifying
  // whether all the mutations on a given tensor are hidden from autograd or
  // not. If we have an input mutation that is hidden from autograd, then once
  // we convert the input mutation to a copy_() we know it will be safe to hide
  // the copy_() from autograd as well.
  bool has_metadata_mutation_ = false;
  bool is_multi_output_view_ = false;
  // Did the tensor experience a set_() call.
  bool was_storage_changed_ = false;
  // Did the tensor experience any view operation with symbolic int.
  bool is_symbolic_ = false;

  size_t generation_ = 0;
  std::vector<at::functionalization::ViewMeta> view_metas_;

 protected:
  static void copy_tensor_metadata(
      const FunctionalTensorWrapper* src_impl,
      FunctionalTensorWrapper* dest_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change);
};
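
// Example (editor's illustrative sketch, not a verbatim excerpt from the
// pass): in terms of the wrapper API above, functionalizing a mutation like
// `a.add_(1)` roughly composes replace_/commit_update/sync_ as follows, where
// `a` and `b` are aliased FunctionalTensorWrapper pointers:
//
//   at::Tensor tmp = a->value().add(1);  // out-of-place kernel on the inner tensor
//   a->replace_(tmp);                    // swap value_ out for the result
//   a->commit_update();                  // send the new value_ as a pending update to the alias
//   b->sync_();                          // an alias b replays the update the next time it syncs
//
// The real call sites live in the functionalization kernels, not in this
// header; this is only meant to show how the pieces fit together.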

// Utility functions for the functionalization pass.

namespace functionalization {
namespace impl {

TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
    const Tensor& tensor) {
  auto functional_impl =
      static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
  return functional_impl;
}
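
// Example (editor's illustrative sketch): unsafeGetFunctionalWrapper performs
// an unchecked static_cast, so a caller is expected to verify
// isFunctionalTensor() (declared below) first. `inspect` is a hypothetical
// helper, not part of this header:
//
//   void inspect(const at::Tensor& t) {
//     if (at::functionalization::impl::isFunctionalTensor(t)) {
//       auto* wrapper =
//           at::functionalization::impl::unsafeGetFunctionalWrapper(t);
//       const at::Tensor& inner = wrapper->value();  // plain tensor underneath
//       // ... inspect `inner`, wrapper->level(), etc.
//     }
//   }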

TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
    const c10::List<std::optional<Tensor>>& t_list);
TORCH_API bool isFunctionalTensor(ITensorListRef list);

TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
TORCH_API std::optional<Tensor> to_functional_tensor(
    const std::optional<Tensor>& tensor);
TORCH_API c10::List<std::optional<Tensor>> to_functional_tensor(
    const c10::List<std::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);

TORCH_API void freeze_functional_tensor(const Tensor& tensor);

TORCH_API Tensor
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
TORCH_API std::optional<Tensor> from_functional_tensor(
    const std::optional<Tensor>& t,
    bool assert_functional = true);
TORCH_API c10::List<std::optional<Tensor>> from_functional_tensor(
    const c10::List<std::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);

TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const std::optional<Tensor>& t);
TORCH_API void sync(const c10::List<std::optional<Tensor>>& t_list);
TORCH_API void sync(ITensorListRef t_list);
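
// Example (editor's illustrative sketch of how the wrap/unwrap helpers above
// compose; `run_functionalized` is a hypothetical caller, not an API in this
// header):
//
//   at::Tensor run_functionalized(const at::Tensor& input) {
//     // Wrap the input so subsequent ops see a FunctionalTensorWrapper.
//     at::Tensor wrapped =
//         at::functionalization::impl::to_functional_tensor(input);
//     // ... dispatch ops on `wrapped` under the Functionalize dispatch key ...
//     // Make sure any pending mutations / view replays have been applied.
//     at::functionalization::impl::sync(wrapped);
//     // Unwrap back to a plain tensor for the caller.
//     return at::functionalization::impl::from_functional_tensor(wrapped);
//   }
//
// Real callers (e.g. the functionalize fallback kernel) additionally manage
// the Functionalize dispatch key and use the optional/list overloads above.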

TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
TORCH_API void replace_(
    const ITensorListRef functional_tensor,
    ITensorListRef other);

TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);

TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);

TORCH_API void mark_mutation_hidden_from_autograd(
    const Tensor& functional_tensor);
TORCH_API bool are_all_mutations_hidden_from_autograd(
    const Tensor& functional_tensor);
TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
    const Tensor& functional_tensor);

// These two methods are XLA-specific logic and are no-ops
// for the normal functionalization flow.
TORCH_API void propagate_xla_data(
    const Tensor& functional_tensor,
    const Tensor& other);
TORCH_API void propagate_xla_data(
    const ITensorListRef functional_tensor,
    ITensorListRef other);

Tensor create_functional_tensor_with_view_meta(
    const Tensor& view_to_wrap,
    const Tensor& base,
    functionalization::ViewMeta meta,
    int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
    ITensorListRef view_to_wrap,
    const Tensor& base,
    const functionalization::ViewMeta& meta);

void mutate_view_meta(
    const Tensor& self,
    const functionalization::ViewMeta& meta);

void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(
    const std::vector<Tensor>& outs,
    const std::vector<Tensor>& meta_outs);
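
// Example (editor's rough sketch of the intended call pattern for the view
// helpers above; `view_output` and `meta` are placeholders, and the concrete
// kernels are code-generated rather than written by hand):
//
//   // 1. Compute the view on the inner (unwrapped) tensor -> `view_output`.
//   // 2. Build a ViewMeta `meta` describing how to replay / reverse the view.
//   // 3. Wrap the result so it shares an alias with `base`:
//   //      Tensor out =
//   //          create_functional_tensor_with_view_meta(view_output, base, meta);
//   // For inplace-view ops (e.g. transpose_), the kernel instead calls
//   //      mutate_view_meta(self, meta);
//   // which appends `meta` to self's view_metas_ stack and regenerates self.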

// ~~~~~ TLS used in functionalization ~~~~~

TORCH_API bool getFunctionalizationReapplyViewsTLS();
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);

class TORCH_API FunctionalizationReapplyViewsGuard {
 public:
  FunctionalizationReapplyViewsGuard(bool reapply_views)
      : prev_(getFunctionalizationReapplyViewsTLS()) {
    setFunctionalizationReapplyViewsTLS(reapply_views);
  }

  ~FunctionalizationReapplyViewsGuard() {
    setFunctionalizationReapplyViewsTLS(prev_);
  }

  FunctionalizationReapplyViewsGuard(
      const FunctionalizationReapplyViewsGuard&) = delete;
  FunctionalizationReapplyViewsGuard operator=(
      const FunctionalizationReapplyViewsGuard&) = delete;
  FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
      delete;
  FunctionalizationReapplyViewsGuard operator=(
      FunctionalizationReapplyViewsGuard&&) = delete;

 private:
  bool prev_;
};
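
// Example (editor's illustrative sketch): the guard is a standard RAII wrapper
// around the TLS above, so a caller that temporarily wants views reapplied as
// real view ops (rather than view_copy ops) scopes it like this:
//
//   {
//     at::functionalization::impl::FunctionalizationReapplyViewsGuard guard(
//         /*reapply_views=*/true);
//     // ... code here observes getFunctionalizationReapplyViewsTLS() == true ...
//   }  // previous TLS value is restored when `guard` goes out of scope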

} // namespace impl

// Helper function to call an out-of-place composite aten kernel that may use
// mutations / views internally, and functionalize them.
TORCH_API void functionalize_op_helper(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack);

template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op final {};

template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
  static ReturnType call(
      typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
    using FuncType = ReturnType(
        typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
    auto op = c10::Dispatcher::singleton()
                  .findSchemaOrThrow(
                      (const char*)Op::name, (const char*)Op::overload_name)
                  .typed<FuncType>();

    return c10::impl::BoxedKernelWrapper<FuncType>::call(
        c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
        op,
        // BoxedKernelWrapper knows to ignore this keyset argument,
        // because functionalize_op_helper doesn't take in a DispatchKeySet
        c10::DispatchKeySet(),
        args...);
  }
};

template <class Op>
using functionalize_aten_op =
    _functionalize_aten_op<Op, false, typename Op::schema>;
template <class Op>
using functionalize_aten_op_symint =
    _functionalize_aten_op<Op, true, typename Op::schema>;
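
// Example (editor's illustrative sketch): these aliases are meant to be
// instantiated with one of the generated operator structs, e.g. via the
// ATEN_OP macro from <ATen/Operators.h> (assumed here; `some_op` and `args`
// are placeholders):
//
//   // Routes the call through functionalize_op_helper, which functionalizes
//   // any mutations/views the composite kernel performs internally.
//   at::functionalization::functionalize_aten_op<ATEN_OP(some_op)>::call(
//       args...);
//
// Use functionalize_aten_op_symint instead when the operator's SymInt
// arguments should be kept symbolic.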

} // namespace functionalization
} // namespace at