# functional_tensor.py

# mypy: allow-untyped-defs
import contextlib
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, ContextManager, Dict, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
from torch.utils._python_dispatch import (
    _detect_infra_mode,
    _disable_infra_mode,
    return_and_correct_aliasing,
    TorchDispatchMode,
)


not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")


# NOTE: Some special handling is needed for tensor conversion during export.
# Normally, when tracing through the model with tensor.to(), the maybe-aliasing
# relationship between input and output tensors will be baked into the graph.
# For example, if we get a tensor on device cpu and call tensor.to("cpu"),
# it becomes a no-op in the graph. For whole-graph capture this is not sound,
# so we need to do something different. Instead, in export we try to preserve
# the tensor conversion by forcing a non-semantic-breaking aten::_to_copy
# operator to be traced in the graph, and we subsequently ban mutations on all
# such converted tensors.
# In addition to patching the .to() method call in functionalization, we also have
# to patch other similar methods like float() and cpu(), because they intentionally
# do not fall back to .to(), even though they have the same behavior as .to()
# according to the PyTorch documentation:
# https://pytorch.org/docs/stable/generated/torch.Tensor.float.html
# We therefore simply force them to go through a .to() call.
def _conversion_method_template(**extra_kwargs):
    def _(self, *args, **kwargs):
        return self.to(*args, **{**kwargs, **extra_kwargs})

    return _
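
# For illustration: the template above just forwards to .to() with the extra kwargs
# bound, so an attribute defined on FunctionalTensor below such as
#     half = _conversion_method_template(dtype=torch.float16)
# makes `t.half()` behave like `t.to(dtype=torch.float16)`, routing every conversion
# method through the single patched .to() entry point described in the NOTE above.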


class FunctionalTensor(torch.Tensor):
    """
    Functional tensors represent tensors that will remove mutations
    from a program. If you perform a mutable operation on a functional tensor,
    it will re-dispatch to the functional variant of that operation.

    Historically, functionalization is implemented in C++ in the dispatcher.
    This class is a lightweight python shim around the C++ functionalization logic.

    FunctionalTensor is required to be used with a corresponding
    FunctionalTensorMode active, because it relies
    on using the mode for dispatch (which can properly handle factory functions).
    """

    elem: torch.Tensor

    # Indicates to our torch_dispatch dispatching infra that
    # this is an "infra" mode with lower dispatching precedence.
    _mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL

    # Note: The reason we add these extra keys to our FunctionalTensor subclass
    # is to mirror the behavior of C++ functionalization (we can choose to change this
    # later, as long as it doesn't break anything).
    # FunctionalTensorWrapper copies **all** dispatch keys from the inner tensor
    # to the wrapper, excluding functorch and python dispatch keys.
    # Here I'm trying to re-use the keyset the functorch wrapper subclasses copy,
    # except that they don't include ZeroTensor so I'm manually adding it in.
    _extra_dispatch_keys = torch._C._additional_keys_to_prop_for_wrapper_tensors.add(
        torch._C.DispatchKey.ZeroTensor
    )

    # These are all aten ops that correspond to metadata queries.
    # We want FunctionalTensor to be able to handle them directly.
    metadata_fns = [
        torch.ops.aten.is_contiguous.default,  # type: ignore[has-type]
        torch.ops.aten.is_contiguous.memory_format,  # type: ignore[has-type]
        torch.ops.aten.is_strides_like_format.default,  # type: ignore[has-type]
        torch.ops.aten.is_non_overlapping_and_dense.default,  # type: ignore[has-type]
        torch.ops.aten.size.default,  # type: ignore[has-type]
        torch.ops.aten.sym_size.default,  # type: ignore[has-type]
        torch.ops.aten.stride.default,  # type: ignore[has-type]
        torch.ops.aten.sym_stride.default,  # type: ignore[has-type]
        torch.ops.aten.storage_offset.default,  # type: ignore[has-type]
        torch.ops.aten.sym_storage_offset.default,  # type: ignore[has-type]
        torch.ops.aten.numel.default,  # type: ignore[has-type]
        torch.ops.aten.sym_numel.default,  # type: ignore[has-type]
        torch.ops.aten.dim.default,  # type: ignore[has-type]
        torch.ops.prim.device.default,  # type: ignore[has-type]
    ]

    # These are ops that claim to be functional, but actually are maybe-mutating/maybe-aliasing
    # TODO (tmanlaibaatar) make it a tag
    maybe_aliasing_or_mutating_ops = [
        torch.ops.aten.dropout.default,  # type: ignore[has-type]
        torch.ops.aten.batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten.native_batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten._batch_norm_impl_index.default,  # type: ignore[has-type]
        torch.ops.aten.cudnn_batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten.miopen_batch_norm.default,  # type: ignore[has-type]
    ]
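
    # For illustration, one reason these ops are only "maybe" aliasing/mutating:
    # aten.dropout with train=False simply returns its input (an alias), while with
    # train=True it returns a fresh tensor, so the op's functional-looking schema
    # alone cannot be trusted.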

    def __new__(cls, elem):
        assert torch._is_functional_tensor(elem)

        # In general, we'd like our functional tensor subclass to only be in charge of functionalization,
        # and defer to the inner subclass for all other functionality.
        # Example: If our inner tensor is a ZeroTensor, we would want to defer running the ZeroTensor fallback
        # until after we redispatch to our inner ZeroTensor.
        # However, there are a few keys that we need to mirror between the inner and outer tensors.
        #   Conjugate
        #   Negative
        # Why? These keys are used to test metadata queries, like `.is_conj()` and `.is_neg()`.
        # We **need** calls to is_conj() to return the same thing on the outer and inner tensors,
        # because user code / framework code that branches like so needs to do the same thing
        # when it sees the outer FunctionalTensor:
        #     if (x.is_conj()) {
        #         return at::view_as_real(x.resolve_conj());
        #     } else {
        #         return at::view_as_real(x);
        #     }
        extra_dispatch_keys = (
            FunctionalTensor._extra_dispatch_keys & torch._C._dispatch_keys(elem)
        )
        out = torch.Tensor._make_wrapper_subclass(  # type: ignore[arg-type, attr-defined]
            # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
            # Calling the overload that has kwargs causes us to go down the first overload path,
            # which will **always** specialize sizes.
            # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
            cls,
            elem.shape,  # sizes
            elem.stride(),  # strides
            elem.storage_offset(),  # storage_offset
            None,  # memory_format
            elem.dtype,  # dtype
            elem.layout,  # layout
            elem.device,  # device
            False,  # pin_memory
            elem.requires_grad,  # requires_grad
            "sizes",  # dispatch_sizes_strides_policy
            False,  # dispatch_device
            False,  # dispatch_layout
            extra_dispatch_keys,  # _extra_dispatch_keys
        )
        torch._C._set_throw_on_mutable_data_ptr(out)
        out.elem = elem
        return out

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        unrecognized_types = [
            t
            for t in types
            if t not in [torch.Tensor, torch._subclasses.FakeTensor, FunctionalTensor]
        ]
        if unrecognized_types:
            not_implemented_log.debug(
                "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        if kwargs is None:
            kwargs = {}

        # FunctionalTensor needs to plumb all metadata requests to the inner tensor.
        # In theory we don't have to do this - but if we want to service metadata requests here,
        # we need to carefully make sure all metadata is accurate (including metadata mutations)
        if func in FunctionalTensor.metadata_fns:
            # All metadata accesses should be plumbed to the inner tensor, that way we don't have to worry
            # about the problem of keeping metadata in sync between the wrapper and inner tensor.
            # This also alleviates us from having to manually handle metadata mutations on the wrapper.
            assert len(kwargs) == 0
            if func in [
                torch.ops.aten.is_strides_like_format.default,
                torch.ops.aten.is_contiguous.memory_format,
            ]:
                assert len(args) == 2 and isinstance(args[0], FunctionalTensor)
                return func(args[0].elem, args[1])

            assert len(args) == 1 and isinstance(args[0], FunctionalTensor)
            return func(args[0].elem)

        # Originally I tried to implement my subclass without giving it a torch_dispatch, but I gave up:
        # - _make_wrapper_subclass requires a __torch_dispatch__
        # - If we want to use _make_subclass(), we have a problem: the subclass will share a TensorImpl with the inner tensor,
        #   which is of type FunctionalTensorWrapper! We explicitly do not want our wrapper to be a FunctionalTensorWrapper.
        # - If we use the default tensor.__new__(), we have another problem: it returns inner_tensor.alias(),
        #   which causes every subclass created above autograd to have autograd view metadata
        #   (in addition to also being a FunctionalTensorWrapper).
        raise RuntimeError(
            "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()"
        )

    def __repr__(self):
        return f"FunctionalTensor({repr(self.elem)})"

    @staticmethod
    def to_functional(x):
        # We will do the wrapping for the user.
        assert not torch._is_functional_tensor(x)
        # The only autograd metadata we care about on the FunctionalTensor is:
        # - requires_grad (so autograd runs)
        # - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine)
        #   this is handled by FunctionalTensor.to_functional
        x_functional = torch._to_functional_tensor(x)
        # Technically the FunctionalTensorMode here is unnecessary,
        # but it avoids spurious NotImplemented logs during `ProxyTorchDispatchMode` tracing.
        # _mirror_autograd_meta_to queries tensor sizes,
        # and otherwise the sym_size() call will go to the proxy mode before hitting
        # FunctionalTensor.__torch_dispatch__
        functional_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        assert functional_mode is not None
        with functional_mode:
            torch._mirror_autograd_meta_to(x, x_functional)  # type: ignore[attr-defined]
            out = FunctionalTensor(x_functional)
            torch._mirror_autograd_meta_to(x_functional, out)  # type: ignore[attr-defined]
        return out

    def from_functional(self):
        torch._sync(self)
        return torch._from_functional_tensor(self.elem)

    def replace_(self, output) -> None:
        torch._functionalize_replace(self.elem, output)

    def commit_update(self) -> None:
        torch._functionalize_commit_update(self.elem)

    def sync(self) -> None:
        torch._functionalize_sync(self.elem)

    def mark_mutation_hidden_from_autograd(self) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(self.elem)

    def tolist(self) -> Any:
        if self.elem.dim() == 0:
            return self.elem.item()
        elif self.elem.dim() == 1:
            return [elem.item() for elem in self.elem]
        else:
            return [elem.tolist() for elem in self.elem]

    def to(self, *args, **kwargs):
        if _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL).export:
            # If copy is specified as pos arg, it's always the second one.
            if len([arg for arg in args if isinstance(arg, bool)]) <= 1:
                return super().to(*args, **{**kwargs, "copy": True})
        return super().to(*args, **kwargs)

    def cuda(self, device=None, *args, **kwargs):
        device = device or torch.cuda.current_device()
        if len(args) > 0:
            return self.to(device, *args, **kwargs)
        else:
            return self.to(device=device, **kwargs)

    char = _conversion_method_template(dtype=torch.int8)
    cpu = _conversion_method_template(device=torch.device("cpu"))
    bfloat16 = _conversion_method_template(dtype=torch.bfloat16)
    byte = _conversion_method_template(dtype=torch.uint8)
    double = _conversion_method_template(dtype=torch.float64)
    float = _conversion_method_template(dtype=torch.float32)
    bool = _conversion_method_template(dtype=torch.bool)
    half = _conversion_method_template(dtype=torch.float16)
    int = _conversion_method_template(dtype=torch.int32)
    long = _conversion_method_template(dtype=torch.int64)
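

# For illustration, a minimal hypothetical sketch of how FunctionalTensor is meant to
# be used together with FunctionalTensorMode (defined below); the tensor and the
# in-place op are stand-ins:
#
#     x = torch.randn(3)
#     with FunctionalTensorMode():
#         ft = FunctionalTensor.to_functional(x)  # wrapping requires an active mode
#         ft.add_(1)                              # mutation re-dispatches to functional ops
#         result = ft.from_functional()           # sync + unwrap back to a plain tensor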


class FunctionalTensorMode(TorchDispatchMode):
    def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
        self.export = export
        self.is_on_stack = False
        self.enter_stack = []
        # Indicates to our torch_dispatch dispatching infra that
        # this is an "infra" mode with lower dispatching precedence.
        self._mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL
        self.pre_dispatch = pre_dispatch
        # This will be turned off later for pre-dispatch functionalization
        self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None  # type: ignore[attr-defined]
        # Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help keep
        # track of the ordering between side effectful operations.
        self._tokens: Dict[Any, torch.Tensor] = {}

        # Functionalization runs twice in AOTAutograd, once in
        # `run_functionalized_fw_and_collect_metadata` to collect metadata to
        # see which tensors need to be functionalized and discover how many
        # tokens we need, and another time in `make_fx` which does the actual
        # tracing to replace ops with their functional variants and handling
        # side-effectful ops. In the second stage there should be no token
        # discovery. This flag distinguishes between the two stages.
        self._allow_token_discovery = _allow_token_discovery

    # No-op if FunctionalTensorMode is already in use
    def __enter__(self):
        def _get_prev_mode():
            if self._dispatch_key == torch._C.DispatchKey.PreDispatch:
                return _get_dispatch_mode_pre_dispatch(
                    torch._C._TorchDispatchModeKey.FUNCTIONAL
                )
            return torch._C._get_dispatch_mode(
                torch._C._TorchDispatchModeKey.FUNCTIONAL
            )

        if _get_prev_mode() is None:
            self.enter_stack.append(True)
            return super().__enter__()
        else:
            self.enter_stack.append(False)
            return self

    def __exit__(self, a, b, c):
        is_on_stack = self.enter_stack.pop()
        if is_on_stack:
            super().__exit__(a, b, c)
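
    # For illustration: because __enter__ first checks for an already-active FUNCTIONAL
    # infra mode, nested enters are no-ops. A hypothetical sketch:
    #
    #     mode = FunctionalTensorMode()
    #     with mode:
    #         with mode:  # pushes nothing; enter_stack records False
    #             ...
    #     # only the outermost __exit__ actually pops the mode off the stack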

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        unrecognized_types = [
            t
            for t in types
            if not issubclass(t, torch._subclasses.FakeTensor)
            and t not in [torch.Tensor, FunctionalTensor]
        ]
        if unrecognized_types:
            not_implemented_log.debug(
                "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        def _can_decompose(func):
            # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832
            # We never decompose dropout in export
            if self.export and func == torch.ops.aten.dropout.default:
                return False

            # TODO (tmanlaibaatar)
            # Eventually, we don't want to decompose any aten op at all
            # but there is a safety and coverage gap that we need to close
            # before that.
            #
            # (1) the "safety" is what we are risking with this PR
            #     (we are blindly taking every op that advertises as
            #     functional and sending it to the functional fallback.
            #     We risk silent correctness issues if we have an op that lies about
            #     its schema and that we didn't manually hardcode above).
            #     Therefore we always decompose them.
            # (2) the "not every composite inplace op has a functional variant" issue is
            #     a coverage gap, but not really a safety risk, since we'll loudly error
            #     when we try to generate functionalization kernels for these new
            #     (composite) inplace/view ops. But until we establish that gap more
            #     concretely, we still decompose them.
            if self._dispatch_key is not None:
                # it is unsafe to not decompose ops that claim to be functional but actually aren't
                if func in FunctionalTensor.maybe_aliasing_or_mutating_ops:
                    return True
                # only decompose view or inplace mutating ops
                alias_info = len(
                    [i for i in func._schema.arguments if i.alias_info is not None]
                )
                should_decompose = alias_info != 0 or func._schema.is_mutable
                if not should_decompose:
                    if func.namespace not in ["aten", "prim"]:
                        warnings.warn(
                            f"At pre-dispatch tracing, we will assume that any "
                            f"custom op that is marked with CompositeImplicitAutograd "
                            f"and is functional is safe to not decompose. We found {func}"
                            f" to be one such op."
                        )
                return should_decompose

            return True

        if (
            func not in FunctionalTensor.metadata_fns
            and _can_decompose(func)
            # Not all funcs from __torch_dispatch__ are actual dispatcher ops,
            # e.g. prim.device
            and torch._C._dispatch_has_kernel(func.name())
        ):
            with self:
                r = func.decompose(*args, **kwargs)
                if r is not NotImplemented:
                    return r

        def assert_is_functional(x):
            assert torch._is_functional_tensor(x)

        def wrap(x):
            # Only wrap our outputs in subclasses if the inner functionalization call
            # also wrapped outputs into FunctionalTensorWrappers.
            # When can this happen? e.g. `torch.div(2, 2)`
            assert not isinstance(x, FunctionalTensor)
            if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
                return FunctionalTensor(x)
            return x

        def unwrap(x):
            return x.elem

        from torch._higher_order_ops.auto_functionalize import (
            can_auto_functionalize,
            do_auto_functionalize,
        )

        if can_auto_functionalize(
            func
        ) and not torch._C._dispatch_has_kernel_for_dispatch_key(
            func.name(), torch._C.DispatchKey.Functionalize
        ):
            # it doesn't matter what mode we use here because
            # the implementation of do_auto_functionalize doesn't
            # interact with FunctionalTensorMode at all
            return do_auto_functionalize(func, args, kwargs)

        from torch._higher_order_ops.effects import handle_effects, has_effects

        if has_effects(func, args, kwargs):
            assert not torch._C._dispatch_has_kernel_for_dispatch_key(
                func.name(), torch._C.DispatchKey.Functionalize
            )
            return handle_effects(
                self._allow_token_discovery, self._tokens, func, args, kwargs
            )

        args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
            FunctionalTensor, unwrap, (args, kwargs)
        )

        # Expectation: functionalization should not **already** be enabled above our mode.
        # Why would that be bad? when we return a FunctionalTensor here, we don't want functionalization
        # to run above this mode and further wrap that output in **another** C++ FunctionalTensorWrapper.
        is_included = torch._C._dispatch_tls_is_dispatch_key_included(
            torch._C.DispatchKey.Functionalize
        )
        is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.Functionalize
        )
        assert is_excluded or not is_included

        include_to_set = (
            torch._C._dispatch_tls_local_include_set()
            | torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )
        exclude_to_set = (
            torch._C._dispatch_tls_local_exclude_set().remove(
                torch._C.DispatchKey.Functionalize
            )
            - FunctionalTensor._extra_dispatch_keys
        )

        # All we want to do here is re-use the existing C++ functionalization logic.
        # This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
        with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
            try:
                # By default for python functionalization (for AOTAutograd), we reapply views.
                old_apply_views = torch._functionalize_enable_reapply_views(True)  # type: ignore[attr-defined]

                # Sometimes these functions cannot be directly dispatched to functionalize key
                # because args are sometimes not functional tensors for some reason?
                if func in FunctionalTensor.metadata_fns:
                    outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
                    outs_wrapped = pytree.tree_map_only(
                        torch.Tensor, wrap, outs_unwrapped
                    )
                else:
                    # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
                    # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
                    # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
                    # from the TLS in order to avoid infinite looping, but this would prevent us from coming
                    # back to PreDispatch later
                    outs_unwrapped = func._op_dk(
                        torch._C.DispatchKey.Functionalize,
                        *args_unwrapped,
                        **kwargs_unwrapped,
                    )
                    # We don't allow any mutation on result of dropout or _to_copy
                    if self.export:
                        if func in (
                            torch.ops.aten.dropout.default,
                            torch.ops.aten._to_copy.default,
                        ):
                            torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
                    outs_wrapped = pytree.tree_map_only(
                        torch.Tensor, wrap, outs_unwrapped
                    )
            finally:
                torch._disable_functionalization()
                torch._functionalize_enable_reapply_views(old_apply_views)  # type: ignore[attr-defined]

        is_included = torch._C._dispatch_tls_is_dispatch_key_included(
            torch._C.DispatchKey.Functionalize
        )
        is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.Functionalize
        )
        assert is_excluded or not is_included

        if (
            # If no outputs are our functional subclass, then don't try to fix up aliasing
            not any(
                isinstance(x, FunctionalTensor)
                for x in pytree.tree_leaves(outs_wrapped)
            )
            # Since lift_fresh lifts its argument into a functional tensor, we can skip the
            # aliasing correction step. Otherwise, we would be setting the storage of a
            # lifted tensor to that of an unlifted tensor.
            # Ref: https://github.com/pytorch/pytorch/issues/111506
            or func == torch.ops.aten.lift_fresh.default
        ):
            return outs_wrapped

        # Wrapper tensor subclasses do not have correct aliasing info! Use this util to manually correct the output aliasing.
        # inplace ops like `aten.add_()` are expected to return inputs **directly**, instead of creating fresh tensor objects.
        # Use this util to figure out the right thing to return.
        # If none of our inputs were wrapped, then we have no FunctionalTensor outputs that we need to fix up storages for.
        return return_and_correct_aliasing(func, args, kwargs, outs_wrapped)


@contextlib.contextmanager
def disable_functional_mode():
    # A generator-based context manager must yield; simply returning the context
    # manager produced by _disable_infra_mode would fail on __enter__.
    with _disable_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) as mode:
        yield mode
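

# For illustration, a hypothetical sketch of temporarily stepping out of an active
# FunctionalTensorMode, e.g. to run ops without python functionalization intercepting
# them:
#
#     with FunctionalTensorMode():
#         ...
#         with disable_functional_mode():
#             ...  # the mode is unset here and re-pushed on exit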


# This is similar to torch.func.functionalize, but:
# - It uses FunctionalTensorMode, and FunctionalTensor (a python subclass).
#   One important advantage to using this mode is that it will let us
#   run functionalization underneath __torch_dispatch__,
#   which we need in AOTAutograd.
# - Doing so means that it does not automatically compose with other
#   functorch transforms, since these transforms always run above __torch_dispatch__.
# That's why this util lives here, and not in functorch.
def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMode()):
    # TODO: pull these from aot autograd
    def to_fun(t):
        if isinstance(t, torch.Tensor):
            return FunctionalTensor.to_functional(t)
        return t

    def from_fun(t):
        if not isinstance(t, FunctionalTensor):
            # quick sanity assert
            if isinstance(t, torch.Tensor):
                assert not torch._is_functional_tensor(t)
            return t
        torch._sync(t)
        return torch._from_functional_tensor(t.elem)

    def inner(*args, **kwargs):
        disable_above = torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )
        with disable_above, mode:
            func_args = pytree.tree_map_only(torch.Tensor, to_fun, args)
            func_kwargs = pytree.tree_map_only(torch.Tensor, to_fun, kwargs)
            func_outputs = func(*func_args, **func_kwargs)
            outputs = pytree.tree_map_only(FunctionalTensor, from_fun, func_outputs)

            return outputs

    return inner
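

# For illustration, a hypothetical sketch of dispatch_functionalize: `fn` mutates an
# intermediate tensor, and the returned wrapper runs it under FunctionalTensorMode so
# the mutation is replaced with functional ops:
#
#     def fn(x):
#         y = x.clone()
#         y.add_(1)
#         return y
#
#     functional_fn = dispatch_functionalize(fn)
#     out = functional_fn(torch.randn(3))  # no mutations are observable from outside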


class BaseFunctionalizeAPI(ABC):
    @abstractmethod
    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        pass

    @abstractmethod
    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        pass

    @abstractmethod
    def functionalize(self, inner_f: Callable) -> Callable:
        pass

    @abstractmethod
    def redispatch_to_next(self) -> ContextManager:
        pass

    @abstractmethod
    def replace(self, input_tensor, output_tensor) -> None:
        pass

    @abstractmethod
    def commit_update(self, tensor) -> None:
        pass

    @abstractmethod
    def sync(self, tensor) -> None:
        pass

    @abstractmethod
    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        pass


class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
    def __init__(
        self, mode: Optional[FunctionalTensorMode] = None, pre_dispatch: bool = False
    ) -> None:
        super().__init__()
        self.mode = mode if mode else FunctionalTensorMode()
        self.pre_dispatch = pre_dispatch

    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        with self.mode:
            return torch.utils._pytree.tree_map_only(
                torch.Tensor, FunctionalTensor.to_functional, args
            )

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        return torch.utils._pytree.tree_map_only(
            FunctionalTensor, FunctionalTensor.from_functional, args
        )

    def functionalize(self, inner_f: Callable) -> Callable:
        return dispatch_functionalize(inner_f, self.mode)

    def redispatch_to_next(self) -> ContextManager:
        # [NOTE] We don't do anything here because, by the time
        # we exercise this path, we would have already popped the
        # FunctionalTensorMode from the mode stack. Since FunctionalTensorMode
        # is now stateful, it is better to explicitly pass in the correct mode
        # directly instead of globally setting it.
        return contextlib.nullcontext()

    def replace(self, input_tensor, output_tensor) -> None:
        assert isinstance(input_tensor, FunctionalTensor)
        assert not isinstance(output_tensor, FunctionalTensor)
        input_tensor.replace_(output_tensor)

    def commit_update(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.commit_update()

    def sync(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.sync()

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.mark_mutation_hidden_from_autograd()
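

# For illustration, a hypothetical round trip through PythonFunctionalizeAPI:
#
#     api = PythonFunctionalizeAPI()
#     wrapped = api.wrap_tensors((torch.randn(2),))  # FunctionalTensor inputs
#     ...                                            # run functionalized code on `wrapped`
#     unwrapped = api.unwrap_tensors(wrapped)        # back to plain tensors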


class CppFunctionalizeAPI(BaseFunctionalizeAPI):
    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional

        return _wrap_all_tensors_to_functional(args, level=0)

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        from torch._functorch.eager_transforms import (
            _unwrap_all_tensors_from_functional,
        )

        return _unwrap_all_tensors_from_functional(args, reapply_views=_reapply_views())

    def functionalize(self, inner_f: Callable) -> Callable:
        return torch.func.functionalize(inner_f)

    def redispatch_to_next(self) -> ContextManager:
        return torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )

    def replace(self, input_tensor, output_tensor) -> None:
        torch._functionalize_replace(input_tensor, output_tensor)

    def commit_update(self, tensor) -> None:
        torch._functionalize_commit_update(tensor)

    def sync(self, tensor) -> None:
        torch._functionalize_sync(tensor)

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(tensor)


class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
    def __init__(self, interpreter):
        self.interpreter = interpreter

    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional

        return _wrap_all_tensors_to_functional(args, level=self.interpreter.level())

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        from torch._functorch.eager_transforms import (
            _unwrap_all_tensors_from_functional,
        )

        return _unwrap_all_tensors_from_functional(
            args, reapply_views=self.interpreter.functionalize_add_back_views()
        )

    def functionalize(self, inner_f: Callable) -> Callable:
        return torch.func.functionalize(
            inner_f,
            remove="mutations_and_views"
            if self.interpreter.functionalize_add_back_views()
            else "mutations",
        )

    def redispatch_to_next(self) -> ContextManager:
        return self.interpreter.lower()

    def replace(self, input_tensor, output_tensor) -> None:
        torch._functionalize_replace(input_tensor, output_tensor)

    def commit_update(self, tensor) -> None:
        torch._functionalize_commit_update(tensor)

    def sync(self, tensor) -> None:
        torch._functionalize_sync(tensor)

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(tensor)