# meta_utils.py
# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
import dataclasses
import warnings
import weakref
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    ClassVar,
    ContextManager,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import TypeAlias

import torch
import torch._prims_common as utils  # used by is_c_of_r in MetaConverter.meta_tensor
from torch._C._autograd import CreationMeta
from torch._C._functorch import (
    _add_batch_dim,
    _unwrap_functional_tensor,
    _wrap_functional_tensor,
    get_unwrapped,
    is_batchedtensor,
    is_functorch_wrapped_tensor,
    is_gradtrackingtensor,
    is_legacy_batchedtensor,
    maybe_get_bdim,
    maybe_get_level,
    peek_interpreter_stack,
)
from torch._logging import trace_structured
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import WeakIdKeyDictionary

if TYPE_CHECKING:
    from torch._C._functorch import CInterpreter
    from torch._guards import Source

    # Import here to avoid cycle
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Import the following modules during type checking to enable code intelligence
    # features.  Do not import unconditionally, as they import sympy and importing
    # sympy is very slow.
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext

DimList = List


def safe_is_leaf(t):
    try:
        return t.is_leaf
    except RuntimeError:
        # inference mode can trigger this
        return False


def safe_grad(t):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
        return t.grad


def assert_eq(a, b):
    assert a == b, f"{a} != {b}"


def assert_metadata_eq(
    assert_eq,
    m1: Union[MetaTensorDesc, torch.Tensor],
    m2: torch.Tensor,
    *,
    skip_symbolic=False,
    skip_leaf=False,
):
    if isinstance(m1, torch.Tensor):
        m1 = MetaTensorDescriber().describe_tensor(m1)

    def go(m1, m2):
        assert_eq(m1.dtype, m2.dtype)
        if not skip_symbolic:
            assert_eq(m1.shape, m2.shape)
        assert_eq(m1.requires_grad, m2.requires_grad)
        if not skip_leaf:
            assert_eq(m1.is_leaf, m2.is_leaf)
        # MetaTensorDesc doesn't store grad_fn; it is inferred from is_leaf
        # assert_eq(m1.grad_fn is None, m2.grad_fn is None)
        assert_eq(m1.is_sparse, m2.is_sparse)
        assert_eq(m1.is_inference, m2.is_inference())
        assert_eq(m1.is_conj, m2.is_conj())
        assert_eq(m1.is_neg, m2.is_neg())
        assert_eq(m1.grad is not None, safe_grad(m2) is not None)
        if m1.grad is not None:
            go(m1.grad, safe_grad(m2))
        if m1.is_sparse:
            assert_eq(m1.dense_dim, m2.dense_dim())
            assert_eq(m1.sparse_dim, m2.sparse_dim())
            assert_eq(m1.is_coalesced, m2.is_coalesced())
        else:
            if not skip_symbolic:
                assert_eq(m1.stride, m2.stride())
                assert_eq(m1.storage_offset, m2.storage_offset())
            assert_eq(m1.is_view, m2._is_view())
            if m1.is_view:
                go(m1.base, m2._base)
        # TODO: test if is resizable (no direct query for this atm)
        # TODO: audit AutogradMeta to see if it matches
        # TODO: test forward AD

    return go(m1, m2)
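

# Illustrative usage of assert_metadata_eq (a hedged sketch; `real_t` and
# `meta_t` are hypothetical tensors whose metadata is expected to agree, e.g.
# a real tensor and the meta/fake tensor produced from it):
#
#   assert_metadata_eq(assert_eq, real_t, meta_t, skip_symbolic=True)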


def is_sparse_coo(t):
    return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo


def is_sparse_compressed_layout(layout):
    return layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }


def is_sparse_compressed(t):
    return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout)


def is_sparse_any(t):
    return is_sparse_coo(t) or is_sparse_compressed(t)


# Don't use id() directly, because those can get reallocated over time.
MetaStorageId: TypeAlias = int
MetaTensorId: TypeAlias = int

DESCRIBER_NEXT_ID = 0


class MetaTensorDescriber:
    """
    Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc
    for it, which is enough information to reconstruct a meta tensor/fake tensor
    corresponding to a Tensor as faithfully as possible.

    This is a stateful conversion object because we keep track of the IDs
    of the tensors/storages passed to us, so we can consistently give
    the same ID when we see the same tensor/storage.
    """
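
    # Illustrative usage (a hedged sketch; the example tensor is made up):
    #
    #   describer = MetaTensorDescriber()
    #   t = torch.randn(2, 3)
    #   # Repeated calls on the same tensor produce descriptions with the same
    #   # stable id, which is how aliasing relationships are tracked.
    #   assert describer.describe_tensor(t).id == describer.describe_tensor(t).id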

    def __init__(self, *, copy_data=False):
        global DESCRIBER_NEXT_ID
        self.id = DESCRIBER_NEXT_ID
        DESCRIBER_NEXT_ID += 1
        self.next_tensor_id: MetaTensorId = 0
        self.next_storage_id: MetaStorageId = 0
        # Tensor -> int
        self.lookup_tensor = WeakIdKeyDictionary()
        # Storage -> int
        self.lookup_storage = WeakIdKeyDictionary()
        self.copy_data = copy_data
        self.traced_tensors = set()
        self.traced_storages = set()

    def get_tensor_id(self, t: torch.Tensor):
        if t not in self.lookup_tensor:
            self.lookup_tensor[t] = self.next_tensor_id
            self.next_tensor_id += 1
        return self.lookup_tensor[t]

    def get_storage_id(self, s: torch.UntypedStorage):
        if s not in self.lookup_storage:
            self.lookup_storage[s] = self.next_storage_id
            self.next_storage_id += 1
        return self.lookup_storage[s]

    def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False):
        r = MetaStorageDesc(
            id=self.get_storage_id(s),
            size=s.size(),
            # NB: We don't do the copy yet; copy happens when we start
            # creating the new storages
            data=s if self.copy_data else None,
        )
        if trace and r.id not in self.traced_storages:
            trace_structured(
                "describe_storage",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_storages.add(r.id)
        return r

    def describe_tensor(
        self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
    ):
        is_leaf = safe_is_leaf(t)
        is_view = t._is_view()
        is_sparse = t.is_sparse
        layout = t.layout
        is_nested = t.is_nested
        is_traceable_wrapper_subclass_v = is_traceable_wrapper_subclass(t)
        is_functorch_wrapped = is_functorch_wrapped_tensor(t)
        is_mkldnn = t.is_mkldnn
        is_batchedtensor_v = is_batchedtensor(t)
        is_legacy_batchedtensor_v = is_legacy_batchedtensor(t)
        is_gradtrackingtensor_v = is_gradtrackingtensor(t)
        is_functorch_batched_or_grad = is_batchedtensor_v or is_gradtrackingtensor_v
        is_functional = torch._is_functional_tensor(t)

        storage = None
        # NB: For compatibility, I default this to zero, as sometimes people
        # still have stuffed zero into storage offset even though the tensor
        # doesn't meaningfully have an offset
        storage_offset = 0
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
            or is_mkldnn
            # TODO: TBH, functorch wrapped tensors probably should have
            # storage associated with them
            or is_functorch_wrapped
            or is_legacy_batchedtensor_v
        ):
            # NB: We actually don't use storage to do views, but might as well
            # put it in for accuracy
            storage = self.describe_storage(t.untyped_storage(), trace=trace)
            storage_offset = t.storage_offset()

        stride = None
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
        ):
            # stride/storage_offset are called from is_functorch_wrapped,
            # view_from_base, empty_create_subclass,
            # sym_sizes_strides_storage_offset (empty_create)
            stride = t.stride()

        # NB: this technically should refer to functorch unwrapped tensor, but
        # I am (perhaps abusively) using it to store both the functorch and
        # non-functorch functional tensor
        unwrapped = None
        autograd_meta_from = None
        current_level = None
        if is_batchedtensor_v or is_gradtrackingtensor_v:
            unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace)
        # xla and lazy tensors present as functional tensors, but we want them
        # to be handled specially
        elif is_functional and t.device.type not in ("xla", "lazy"):
            if t._is_view():
                raise RuntimeError(
                    "Cannot safely fakify a view because this process drops the view information right now."
                )
            if not is_functorch_wrapped:
                torch._sync(t)
                unwrapped = self.describe_tensor(
                    torch._from_functional_tensor(t), trace=trace
                )
                autograd_meta_from = t
            else:
                reapply_views = torch._C._functionalization_reapply_views_tls()
                # NB: has side effects!
                unwrapped = self.describe_tensor(
                    _unwrap_functional_tensor(t, reapply_views), trace=trace
                )
                # TODO: It's pretty suspicious that functional tensors don't have
                # a valid level and thus we just grab whatever the current level
                # is
                current_level = torch._C._functorch.current_level()

        maybe_functorch_stack = None
        if is_functorch_wrapped:
            with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack:
                pass

        attrs = None
        ctx = None
        type_v = None
        if is_traceable_wrapper_subclass_v:
            assert hasattr(t, "__tensor_flatten__")
            raw_attrs, ctx = t.__tensor_flatten__()
            attrs = {
                attr: self.describe_tensor(getattr(t, attr), trace=trace)
                for attr in raw_attrs
            }
            type_v = type(t)

        # TODO: Is it important to enable torch.inference_mode before querying
        # these values?
        r = MetaTensorDesc(
            id=self.get_tensor_id(t),
            storage=storage,
            is_inference=t.is_inference(),
            is_leaf=is_leaf,
            requires_grad=t.requires_grad,
            # NB: ndim should be OK too but there is a disaster at
            # python test/dynamo/test_subclasses.py -k test_user_overidden_property_unsupported
            # Actually, this means that we have a little bit of a problem
            # here, which is that there is some sensitivity to how exactly an
            # access is done if you have a __torch_function__ subclass.  Maybe
            # should disable torch function before doing accesses?
            ndim=t.dim(),
            dtype=t.dtype,
            is_sparse=is_sparse,
            is_mkldnn=is_mkldnn,
            is_functorch_wrapped=is_functorch_wrapped,
            is_batchedtensor=is_batchedtensor_v,
            is_legacy_batchedtensor=is_legacy_batchedtensor_v,
            is_gradtrackingtensor=is_gradtrackingtensor_v,
            is_view=is_view,
            is_conj=t.is_conj(),
            is_neg=t.is_neg(),
            is_parameter=isinstance(t, torch.nn.Parameter),
            is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v,
            is_nested=is_nested,
            is_functional=is_functional,
            layout=layout,
            device=t.device,
            size=t.size(),
            stride=stride,
            storage_offset=storage_offset,
            dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
            sparse_dim=t.sparse_dim()
            if t.is_sparse or is_sparse_compressed(t)
            else None,
            dense_dim=t.dense_dim() if t.is_sparse or is_sparse_compressed(t) else None,
            is_coalesced=t.is_coalesced() if t.is_sparse else None,
            # TODO: I actually think recursing here is correct, but we have at
            # least an infinite cycle from base -> values -> base
            # https://github.com/pytorch/pytorch/issues/122089
            crow_indices=self.describe_tensor(
                t.crow_indices(), recurse=False, trace=trace
            )
            if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
            else None,
            col_indices=self.describe_tensor(
                t.col_indices(), recurse=False, trace=trace
            )
            if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
            else None,
            ccol_indices=self.describe_tensor(
                t.ccol_indices(), recurse=False, trace=trace
            )
            if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
            else None,
            row_indices=self.describe_tensor(
                t.row_indices(), recurse=False, trace=trace
            )
            if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
            else None,
            values=self.describe_tensor(t.values(), recurse=False, trace=trace)
            if recurse and is_sparse_compressed(t)
            else None,
            grad=self.describe_tensor(safe_grad(t), trace=trace)
            if safe_grad(t) is not None
            else None,
            creation_meta=torch._C._autograd._get_creation_meta(t)
            if t._is_view()
            else None,
            unwrapped=unwrapped,
            level=maybe_get_level(t)
            if is_batchedtensor_v or is_gradtrackingtensor_v
            else None,
            bdim=maybe_get_bdim(t) if is_batchedtensor_v else None,
            base=self.describe_tensor(t._base, trace=trace)
            if recurse and t._is_view() and t._base is not None
            else None,
            fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t),
            view_func=t._view_func_unsafe,
            attrs=attrs,
            ctx=ctx,
            type=type_v,
            # NB: even if functorch is enabled, don't actually save the
            # interpreter stack here unless we are actually functorch wrapped;
            # it's irrelevant for non-functorch stuff
            functorch_stack=maybe_functorch_stack,
            autograd_meta_from=autograd_meta_from,
            current_level=current_level,
            data=t if self.copy_data else None,
        )
        if trace and r.id not in self.traced_tensors:
            trace_structured(
                "describe_tensor",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_tensors.add(r.id)
        return r


@dataclass(frozen=True)
class MetaStorageDesc:
    id: MetaStorageId
    size: int
    # NB: this is only populated with copy_data True, it is not directly
    # serializable in JSON, you want to do something special here anyway
    data: Optional[torch.UntypedStorage]

    def as_json(self, describer_id):
        return {
            "id": self.id,
            "describer_id": describer_id,
            "size": self.size if isinstance(self.size, int) else repr(self.size),
        }


@dataclass(frozen=True)
class MetaTensorDesc:
    id: MetaTensorId
    ndim: int
    dtype: torch.dtype
    device: torch.device

    # NB: Sometimes, size, stride and storage_offset contain SymInt, in which
    # case this is NOT serializable.  That only happens when you're
    # re-fakeifying a fake tensor with an existing ShapeEnv... maybe we
    # can get rid of this use case entirely.  Notably, even if we are
    # fakeifying a real tensor into a fake tensor with symbolic shapes, the
    # size here is NOT dynamic
    # NB: These also contain SymInt because wrap_meta_outputs_with_default_device_logic
    # goes through this codepath.  But it really should not LOL.
    # NB: size could potentially be None as you can override it and make it
    # throw an error, but we don't currently have any subclasses that do this
    # except C++ nested tensor but we're going to have nested int to make this
    # defined on NJT
    size: Tuple[int, ...]
    dynamo_dynamic_indices: List[int]

    layout: torch.layout = torch.strided
    is_inference: bool = False
    is_leaf: bool = False
    requires_grad: bool = False
    is_sparse: bool = False
    is_mkldnn: bool = False
    is_functorch_wrapped: bool = False
    is_batchedtensor: bool = False
    is_legacy_batchedtensor: bool = False
    is_gradtrackingtensor: bool = False
    is_view: bool = False
    is_nested: bool = False
    is_traceable_wrapper_subclass: bool = False
    is_functional: bool = False
    is_conj: bool = False
    is_neg: bool = False
    is_parameter: bool = False
    stride: Optional[Tuple[int, ...]] = None
    storage_offset: int = 0
    # NB: We have a choice whether or not to store the id or a direct pointer
    # to the data structure.  For ease of use, we store the data structure,
    # but this means that when we serialize, we have to swizzle these pointers
    # back into ids (so we have accurate aliasing relationships)
    storage: Optional[MetaStorageDesc] = None
    sparse_dim: Optional[int] = None  # is_sparse, is_sparse_compressed
    dense_dim: Optional[int] = None  # is_sparse, is_sparse_compressed
    is_coalesced: Optional[bool] = None  # is_sparse
    crow_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    col_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    ccol_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    row_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    values: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    unwrapped: Optional[MetaTensorDesc] = None  # is_functorch_wrapped
    bdim: Optional[int] = None  # is_functorch_wrapped
    base: Optional[MetaTensorDesc] = None  # is_view
    attrs: Optional[Dict[str, MetaTensorDesc]] = None  # is_traceable_wrapper_subclass
    creation_meta: Optional[CreationMeta] = None
    grad: Optional[MetaTensorDesc] = None

    # Everything below is NOT serializable, need some more work
    _UNSERIALIZABLE: ClassVar[List[str]] = [
        "ctx",
        "type",
        "fake_mode",
        "view_func",
        "level",
        "current_level",
        "functorch_stack",
        "autograd_meta_from",
        "data",
    ]

    ctx: Optional[object] = None  # is_traceable_wrapper_subclass
    type: Optional[Type] = None  # is_traceable_wrapper_subclass
    fake_mode: Optional[FakeTensorMode] = None
    view_func: Optional[
        Callable[
            [
                torch.Tensor,
                Callable[[int], int],
                Callable[[torch.Tensor], torch.Tensor],
            ],
            torch.Tensor,
        ]
    ] = None
    # level looks serializable, but actually it is meaningless without
    # the functorch_stack below
    level: Optional[int] = None  # is_functorch_wrapped
    current_level: Optional[int] = None
    functorch_stack: Optional[List[CInterpreter]] = None
    autograd_meta_from: Optional[torch.Tensor] = None

    # This is only populated on copy_data, and typically is not used at all,
    # except for some of our meta-ification paths that don't properly use
    # storage (pro-tip: you should use storage)
    data: Optional[torch.Tensor] = None

    # Faithfully serializing functorch tensors will not be too difficult.
    # We only need to consider grad/vmap interpreters, and their internal
    # state is only bools (mostly what the grad enabled/disabled state
    # should be in the lower layer).  Beyond that, tensors just need to
    # precisely indicate which particular interpreter they correspond
    # to (we then replace level with a pointer to the interpreter stack.)
    # However, this use of functorch is very "non-lexical" so it's not
    # entirely clear how to make it all lexical again, so we haven't done
    # it for now.

    # NB: This will reference numeric IDs, and it is assumed that you've
    # already serialized everything this recursively references
    def as_json(self, describer_id):
        def json(k, v):
            # Some best-effort debugging serialization for unserializable
            # fields (feel free to add other special cases as appropriate)
            if k in ["data", "autograd_meta_from"]:
                return None  # never repr these
            if k in set(MetaTensorDesc._UNSERIALIZABLE):
                return repr(v)
            if isinstance(v, (torch.device, torch.dtype, torch.layout)):
                return repr(v)
            if isinstance(v, torch.SymInt):
                return repr(v)
            if isinstance(v, (tuple, list)):
                return [json(k, v1) for v1 in v]
            if isinstance(v, (MetaStorageDesc, MetaTensorDesc)):
                return v.id
            if isinstance(v, CreationMeta):
                return str(v)
            if k == "attrs" and isinstance(v, dict):
                return {k1: v1.id for k1, v1 in v.items()}
            return v

        r = {
            field.name: json(field.name, getattr(self, field.name))
            for field in dataclasses.fields(self)
            if not (
                getattr(self, field.name) is field.default
                or (
                    field.name == "dynamo_dynamic_indices"
                    and not getattr(self, field.name)
                )
            )
        }
        r.update({"describer_id": describer_id})
        return r

    @property
    def shape(self):
        return self.size


# A more faithful reproduction would do a copy on the entire
# storage, but this needs to be done carefully because the
# underlying storage could have larger extent than is implied
# by size/stride.  The real fix is to properly call
# meta_storage recursively here.
#
# These "safe" functions are intended to be used under no_dispatch() mode.
# The no_dispatch() here is intended to prevent ambient fake tensor mode from
# fakeifying the operation.  But if we are given an honest to goodness
# FakeTensor as src, we MUST NOT run the copy/clone operation.  A better way
# to do this would be to not use no_dispatch and instead just disable fake
# tensor mode only (allowing for subclass dispatch to occur)
def _safe_copy(dst, src):
    if type(src) is not torch.Tensor:
        return
    dst.copy_(src)


def _safe_clone(src):
    if type(src) is not torch.Tensor:
        return None
    return src.clone()
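

# Illustrative call pattern for the helpers above (a hedged sketch; `fake_r`
# and `maybe_src` are hypothetical names):
#
#   with torch.no_grad(), no_dispatch():
#       # silently a no-op if maybe_src is a FakeTensor or other subclass
#       _safe_copy(fake_r.real_tensor, maybe_src)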


# This is a class for converting multiple tensors into meta tensors which
# share the same view/storage structure.  The operation model is you allocate
# one of these, and then call it repeatedly on all the tensors you want to
# convert.  It's important to use the same object for tensors you want to
# share storage because this is how we correlate shared storages to the same
# meta storages.  This class will hold weak references to cached tensors
# and tensor storages.
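#
# Illustrative usage (a hedged sketch only; the input tensor is made up and
# the default callback/shape_env/source arguments are assumed):
#
#   converter = MetaConverter()
#   desc = converter.describer.describe_tensor(torch.randn(4))
#   meta_t = converter.meta_tensor(desc)  # tensor on device="meta"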
class MetaConverter:
    def __init__(self, *, copy_data: bool = False):
        # Maps MetaStorageId to UntypedStorage
        self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
        # Maps MetaTensorId to torch.Tensor (typically a meta tensor or
        # FakeTensor)
        self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
        self.hit = 0
        self.miss = 0
        self.del_hook = None
        self.arg_cnt = 0
        # Ensures real_storage/real_tensor are populated on the resulting
        # metaified storage/tensor.  The naming of this attribute is load
        # bearing: FakeTensor relies on real tensor being set to exactly this
        # value
        self.copy_data = copy_data
        self.describer = MetaTensorDescriber(copy_data=copy_data)

    def successful(self):
        return self.hit > 0 and self.miss == 0

    def get_tensor_memo(self, t: MetaTensorDesc):
        return self.tensor_memo.get(t.id, None)

    def set_tensor_memo(self, t: MetaTensorDesc, v):
        self.tensor_memo[t.id] = v

    def get_storage_memo(self, s: MetaStorageDesc):
        return self.storage_memo.get(s.id, None)

    def set_storage_memo(self, s: MetaStorageDesc, v):
        self.storage_memo[s.id] = v

    def meta_storage(self, s: MetaStorageDesc, callback):
        # If we are fakeifying a tensor that has a secretly-zero-sized storage,
        # we need to make sure to resize the meta storage too.
        if self.get_storage_memo(s) is None:
            r_s = callback(
                lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"),
            ).untyped_storage()
            if self.copy_data:
                # NB: no_dispatch is needed because internally storage copy is
                # implemented as Tensor operations
                with torch.no_grad(), no_dispatch():
                    assert s.data is not None
                    r_s.real_storage = s.data.clone()
            self.set_storage_memo(s, r_s)
            return r_s
        else:
            return self.get_storage_memo(s)

    # This function assumes that it's possible to do the conversion.
    # NB: name here is used in a conventional way by Dynamo; it corresponds
    # precisely to the Source.name() of the tensor we're fakeifying and
    # corresponds to a valid Python expression.  When we construct sub-names
    # as part of this process, we will maintain this invariant!  (Even though
    # other users of this may not need this property to be upheld.)
    def meta_tensor(
        self,
        t: MetaTensorDesc,
        shape_env: Optional[ShapeEnv] = None,
        callback=lambda t: t(),
        source: Optional[Source] = None,
        symbolic_context: Optional[SymbolicContext] = None,
    ):
        if source is None:
            from torch._dynamo.source import ConstantSource

            # TODO: make a dedicated UnknownSource for this?
            source = ConstantSource(
                f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
            )

        # This indicates you set no_dispatch() before calling into this
        # function.  This is an error: we may be creating fake tensors and
        # will perform operations on them which need fake tensor mode to
        # be active.  You will segfault if you are in a no_dispatch() block.
        assert not torch._C._dispatch_tls_local_exclude_set().has(
            torch._C.DispatchKey.Python
        )
        arg_cnt = self.arg_cnt
        self.arg_cnt += 1

        # When we make as_strided calls, we end up generating a guard
        # that the new as_strided tensor is in bounds for the old storage
        # for the base (since as_strided calls can "bust" out of their
        # bounding box.)  This guard is unnecessary: if a user is able
        # to provide us a tensor with the view base setup this way, we
        # don't need to produce a guard, because the fact that they
        # were able to produce the view base means it's in bounds.
        #
        # Now, ordinarily, this guard would be harmless.  However, the
        # generated guard refers to variables bound on the base variable.
        # At the moment, Dynamo doesn't actually guard on x._base, because
        # according to Voz this results in a lot of spurious invalidations,
        # and also if the user doesn't directly make use of _base, it's
        # pointless anyway (because programs should be parametric over
        # whether or not the input tensor is a view or not--unless you're
        # mutating the input, but that's a whole 'nother ballgame).  So
        # for expediency, we suppress these guards so we don't have to
        # deal with this (yet, anyway.)
        #
        # NB: An old version of this code suppressed guards for ALL operations
        # happening during meta conversion, not just as_strided calls.
        # This is too aggressive: we do duck sizing and 0/1 simplification
        # as we allocate variables, and we do need to register guards for
        # these cases.
        maybe_suppress: Callable[[], Any] = contextlib.nullcontext
        if shape_env is not None:
            maybe_suppress = shape_env.suppress_guards

        def sym_sizes_strides_storage_offset(
            t: MetaTensorDesc, src, symbolic_context=symbolic_context
        ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
            assert t.stride is not None
            if shape_env is not None:
                fake_mode = t.fake_mode
                if fake_mode is not None and fake_mode.shape_env is shape_env:
                    # Don't reallocate the sizes; the shape envs are the same,
                    # so reuse the old sizes/strides/etc
                    return (t.size, t.stride, t.storage_offset)
                else:
                    # TODO: deduplicate this
                    t_size = tuple(
                        shape_env._maybe_specialize_sym_int_with_hint(sz)
                        for sz in t.size
                    )
                    t_stride = tuple(
                        shape_env._maybe_specialize_sym_int_with_hint(sd)
                        for sd in t.stride
                    )
                    t_storage_offset = shape_env._maybe_specialize_sym_int_with_hint(
                        t.storage_offset
                    )
                    return shape_env._create_symbolic_sizes_strides_storage_offset(
                        t_size,
                        t_stride,
                        t_storage_offset,
                        [d in t.dynamo_dynamic_indices for d in range(t.ndim)],
                        src,
                        symbolic_context=symbolic_context,
                    )
            else:
                return (t.size, t.stride, t.storage_offset)

        def empty_create(
            inner_t: MetaTensorDesc, inner_src, symbolic_context=symbolic_context
        ):
            (
                inner_sizes,
                inner_strides,
                inner_storage_offset,
            ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context)
            return torch.empty_strided(
                inner_sizes,
                inner_strides,
                dtype=inner_t.dtype,
                device="meta",
            )

        # Creates a subclass instance with empty inner tensors according to the specified
        # symbolic context.
        def empty_create_subclass(
            t: MetaTensorDesc,
            outer_size,
            outer_stride,
            symbolic_context=symbolic_context,
            callback=callback,
            source=source,
        ):
            from torch._dynamo.source import AttrSource
            from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext

            assert t.attrs is not None
            assert t.type is not None
            # NB: t.ctx could be None if the subclass in question has no
            # meaningful context
            assert symbolic_context is None or isinstance(
                symbolic_context, SubclassSymbolicContext
            )

            # Note: transform_subclass will use __tensor_unflatten__ to generate
            # a fresh subclass wrapper with outer sizes / strides according to the
            # outer symbolic context (passed in to this function). Inner size / stride
            # / storage offset symbols are allocated according to the appropriate inner
            # symbolic contexts, after which the checks in transform_subclass() will
            # relate them to the outer metadata as possible.
            #
            # Morally, the code here is the same as transform_subclass, but we've
            # written it from scratch to read EmptyCreateSubclass
            outer_size = outer_size if outer_size is not None else t.size
            outer_stride = outer_stride if outer_stride is not None else t.stride

            def transform(attr, inner_t):
                r = callback(
                    lambda: empty_create(
                        inner_t,
                        AttrSource(source, attr),
                        symbolic_context=(
                            None
                            if symbolic_context is None
                            else symbolic_context.inner_contexts[attr]
                        ),
                    )
                )
                if self.copy_data:
                    with torch.no_grad(), no_dispatch():
                        r.real_tensor = torch.empty_strided(
                            inner_t.size,
                            inner_t.stride,
                            dtype=inner_t.dtype,
                            device=inner_t.device,
                        )
                        assert inner_t.data is not None
                        _safe_copy(r.real_tensor, inner_t.data)
                return r

            transformed_tensors_dict = {
                attr: transform(attr, inner_t) for attr, inner_t in t.attrs.items()
            }

            sub = t.type.__tensor_unflatten__(
                transformed_tensors_dict, t.ctx, outer_size, outer_stride
            )

            # NB: Purposefully guard here to simplify the inner / outer symbols.
            # Using sym_eq() for symbolic comparison can result in an expression that's too
            # difficult to guard on, so we use == here.
            assert sub.shape == outer_size, (
                f"Expected return value from {t.type}__tensor_unflatten__() to have "
                f"shape equal to {outer_size}, but got: {sub.shape}"
            )
            assert sub.stride() == outer_stride, (
                f"Expected return value from {t.type}__tensor_unflatten__() to have "
                f"stride equal to {outer_stride}, but got: {sub.stride()}"
            )

            return sub

        # Returns an all-dynamic symbolic context used for metafying the given tensor with
        # fully dynamic dims. This is useful when fake-ifying intermediate tensors in
        # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we
        # don't want to over-specialize during view replay.
        def all_dynamic_symbolic_context(
            t: MetaTensorDesc, source, shape_env, callback
        ):
            from torch._dynamo.source import AttrSource
            from torch.fx.experimental.symbolic_shapes import (
                DimDynamic,
                StatelessSymbolicContext,
                SubclassSymbolicContext,
            )

            view_base_context: Optional[SymbolicContext] = None
            if t.is_view:
                assert t.base is not None
                view_base_context = all_dynamic_symbolic_context(
                    t.base, AttrSource(source, "_base"), shape_env, callback
                )

            t_symbolic_context: SymbolicContext
            t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim
            if t.is_traceable_wrapper_subclass:
                assert t.attrs is not None
                inner_contexts: Dict[str, SymbolicContext] = {}
                for attr, inner in t.attrs.items():
                    assert isinstance(attr, str)
                    inner_contexts[attr] = all_dynamic_symbolic_context(
                        inner, AttrSource(source, attr), shape_env, callback
                    )
                t_symbolic_context = SubclassSymbolicContext(
                    dynamic_sizes=t_dynamic_sizes,
                    constraint_sizes=[None] * t.ndim,
                    inner_contexts=inner_contexts,
                    tensor_source=source,
                    view_base_context=view_base_context,
                )
            else:
                t_symbolic_context = StatelessSymbolicContext(
                    dynamic_sizes=t_dynamic_sizes,
                    constraint_sizes=[None] * t.ndim,
                    view_base_context=view_base_context,
                )

            return t_symbolic_context

        # Returns a fake-ified version of an input view tensor t, given an already fake-ified
        # base. At a high level, we want two things:
        #   1. fake_t should have the same view relationship to the given fake base as the
        #      input t has to its _base.
        #   2. fake_t should have symbolic sizes / strides / storage offset according to the
        #      appropriate symbolic context (i.e. from the automatic dynamic algorithm).
        #
        # We currently take different strategies across view types:
        #   * For dense -> dense views, accomplish both (1) and (2) simultaneously via an
        #     as_strided() call on the fake-ified base, passing symbolic metadata.
        #   * For views involving subclasses, perform view replay using view funcs to
        #     achieve (1). It's necessary for (2) to swap out any closed-over state in
        #     the view funcs with symbolicized SymInts and fake-ified tensors. Doing this
        #     avoids specialization (and thus over-eager simplification of symbols) that
        #     could occur during view replay on the fake-ified base.
        #
        # Examples:
        #   * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled
        #     with an as_strided() call on the fake base passing symbolic metadata.
        #   * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg
        #     is made symbolic to avoid invalid specialization and view replay is then
        #     done to reconstruct the view.
        #   * _nested_from_jagged(values, offsets) is a dense -> subclass view
        #     that returns a subclass instance from a dense values tensor. The offsets
        #     tensor is closed over in the view func, as it can be considered view metadata.
        #     First, the offsets tensor is fake-ified according to the inner symbolic
        #     context and with the correct relationship to the outer size / stride metadata.
        #     Then view replay is done, swapping in the fake offsets so the view replay output
        #     is fully fake with no invalid specialization.
        def view_from_base(
            base: torch.Tensor, t: MetaTensorDesc, source=source, shape_env=shape_env
        ):
            # fake-ify t's metadata according to the outer symbolic context
            (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset(
                t, source
            )
            if (
                not t.is_traceable_wrapper_subclass
                and not is_traceable_wrapper_subclass(base)
            ):
                # Dense -> Dense view case uses as_strided() to construct view relationship.
                # TODO: Change this logic to use view replay for consistency?
                # It's likely there is no view func available.
                with maybe_suppress():
                    return base.as_strided(sizes, strides, storage_offset)

            from torch._dynamo.source import EphemeralSource
            from torch.fx.experimental.symbolic_shapes import (
                StatelessSymbolicContext,
                sym_eq,
            )

            def symint_visitor_fn(s):
                nonlocal symbolic_context
                from torch.fx.experimental.symbolic_shapes import DimDynamic

                all_static_sizes = (
                    symbolic_context is not None
                    and isinstance(symbolic_context, StatelessSymbolicContext)
                    and all(
                        x is DimDynamic.STATIC for x in symbolic_context.dynamic_sizes
                    )
                )
                # Can't just rely on shape env being None - dynamo always initializes it
                if all_static_sizes or shape_env is None:
                    return s

                # NB: The symbol here is expected to be simplified out because we a priori
                # allocate inner and outer symbols according to the appropriate symbolic
                # contexts and prefer those over this symbol during symbol simplification
                # (via usage of EphemeralSource below). This -shouldn't- happen, but if
                # this symbol somehow leaks out beyond the view tensor's shape metadata, our
                # assumption of it being simplified out will fail and it may be guarded on,
                # which will hard error.
                sym_source = EphemeralSource("symint_visitor_fn")
                symbol = shape_env.create_symbol(s, sym_source)
                return shape_env.create_symintnode(symbol, hint=s, source=sym_source)

            real_to_fake_mapping = {}
            if t.is_traceable_wrapper_subclass:
                assert t.attrs is not None
                # NB: t.ctx could be None if the subclass in question has no
                # meaningful context
                assert t.type is not None

                # Fake-ify t naively here; this is only done so we can get fake-ified inner
                # tensors with the correct relationships to the outer sizes / strides for use
                # in view replay. It's done beforehand here because it's not easy to do when
                # visiting tensors one-by-one during view replay.
                #
                # Example:
                #   Consider a Dense -> NJT view. NJT has (values, offsets) components and we
                #   want a view of values with the offsets closed over. As the offsets component
                #   is needed to describe the output view, it's important that it's fakeified
                #   correctly.
                fake_t = empty_create_subclass(
                    t, outer_size=sizes, outer_stride=strides
                )
                attrs, _ = fake_t.__tensor_flatten__()
                for attr in attrs:
                    real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr)

            def tensor_visitor_fn(
                visited_t: torch.Tensor,
                # These arguments are never passed, we just use them to close
                # over these relevant values
                shape_env=shape_env,
                callback=callback,
            ):
                # It's possible to close over an undefined tensor (e.g. NJT's lengths).
                if visited_t is None:
                    return None

                # NB: visited_t being a Tensor here is very naughty!  Should
                # have already been described

                # Fake inner tensors of view subclasses will come from the mapping built above.
                visited_id = self.describer.get_tensor_id(visited_t)
                fake_visited_t = real_to_fake_mapping.get(visited_id, None)
                if fake_visited_t is not None:
                    return fake_visited_t

                visited_desc = self.describer.describe_tensor(visited_t)

                # For other closed-over tensor state, fake-ify it as all dynamic with an
                # ephemeral source. This avoids invalid specialization during view replay.
                # If we find that in practice the usage of ephemeral sources isn't enough
                # to guarantee that we don't have guards on these symbols, we may need to
                # explicitly suppress guards (as is done for _base in the dense -> dense
                # view case).
                temp_source = EphemeralSource("tensor_visitor_fn")
                return self.meta_tensor(
                    visited_desc,
                    shape_env,
                    callback,
                    source=temp_source,
                    symbolic_context=all_dynamic_symbolic_context(
                        visited_desc, temp_source, shape_env, callback
                    ),
                )

            # Replay the view, swapping out any non-symbolic SymInts or real tensors
            # for symbolic SymInts or fake tensors.
            assert t.view_func is not None
            # NB: we do NOT suppress guards here, we need to remove ephemeral
            # sources
            fake_t = t.view_func(base, symint_visitor_fn, tensor_visitor_fn)

            # Ensure the output has symbolic shapes according to the outer symbolic context.
            # These checks should simplify out any symbols created for closed-over view func
            # SymInts.
            torch._check(sym_eq(fake_t.size(), sizes))
            torch._check(sym_eq(fake_t.stride(), strides))
            torch._check(sym_eq(fake_t.storage_offset(), storage_offset))
            return fake_t

        if self.get_tensor_memo(t) is None:
            GRAD_TENSOR_SENTINEL_VALUE = -2

            with torch.inference_mode(t.is_inference):
                if t.is_sparse:
                    is_leaf = t.is_leaf

                    # The lambda function below is similar to
                    # `t.to(device='meta')` except the latter
                    # preserves nnz value
                    r = callback(
                        lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
                            t.sparse_dim,
                            t.dense_dim,
                            t.size,
                            dtype=t.dtype,
                            layout=torch.sparse_coo,
                            device="meta",
                        )
                    )
                    if self.copy_data:
                        # Pray that sparse clone doesn't lose information
                        assert t.data is not None
                        with torch.no_grad(), no_dispatch():
                            r.real_tensor = _safe_clone(t.data)
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    # Note [is_coalesced is dispatched]
                    # Strangely enough, is_coalesced() is a dispatched operator,
                    # which means that it will get caught by fake tensor mode.
                    # Ordinarily this would error, but there's some logic in
                    # fake tensor to ensure this doesn't happen.
                    r._coalesced_(t.is_coalesced)
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        # This should probably use DelayedError,
                        # but clone is fine for now for sparse tensors.
                        # (DelayedError does not work for sparse because it causes
                        # the Fake sparse tensor to "lose" its fakeness)
                        r = r.clone()
                        with torch.enable_grad():
                            r._coalesced_(t.is_coalesced)
                elif is_sparse_compressed_layout(t.layout):
                    is_leaf = t.is_leaf

                    if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
                        assert t.sparse_dim is not None
                        assert t.dense_dim is not None
                        assert t.values is not None
                        batch_dim = t.ndim - t.sparse_dim - t.dense_dim
                        blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3]
                    else:
                        blocksize = ()
                    if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                        assert t.crow_indices is not None
                        index_dtype = t.crow_indices.dtype
                    else:
                        assert t.ccol_indices is not None
                        index_dtype = t.ccol_indices.dtype

                    r = callback(
                        lambda: torch.ops.aten._sparse_compressed_tensor_with_dims(
                            0,
                            t.dense_dim,
                            t.shape,
                            blocksize,
                            index_dtype,
                            layout=t.layout,
                            dtype=t.dtype,
                            device="meta",
                        )
                    )
                    if self.copy_data:
                        # Pray sparse clone doesn't lose information
                        assert t.data is not None
                        with torch.no_grad(), no_dispatch():
                            r.real_tensor = _safe_clone(t.data)
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        r = torch._C._functions.DelayedError(
                            "Internal error: Tried to backward() through example input",
                            1,
                        )(r)
                elif t.is_nested and not t.is_traceable_wrapper_subclass:
                    # TODO: Handle this better in Dynamo?
                    # There are checks there now, but this can still be triggered by a dense
                    # tensor graph input that is a view of a strided NT.
                    from torch._dynamo.exc import unimplemented

                    unimplemented(
                        "strided nested tensors are not supported by meta conversion"
                    )
                elif t.is_mkldnn:
                    is_leaf = t.is_leaf
                    sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
                        t, source
                    )
                    # TODO: This doesn't seem right, where's the MKLDNN'ness
                    # lol
                    r = callback(
                        lambda: torch.empty_strided(
                            sizes, strides, dtype=t.dtype, device="meta"
                        )
                    )
                    if self.copy_data:
                        with torch.no_grad(), no_dispatch():
                            assert t.size is not None
                            assert t.stride is not None
                            r.real_tensor = torch.empty_strided(
                                t.size, t.stride, dtype=t.dtype, device=t.device
                            )
                            assert t.data is not None
                            _safe_copy(r.real_tensor, t.data)
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        r = torch._C._functions.DelayedError(
                            "Internal error: Tried to backward() through example input",
                            1,
                        )(r)
                elif t.is_functorch_wrapped:
                    if t.is_view:
                        from torch._dynamo.exc import unimplemented

                        unimplemented(
                            "view functorch tensors are not supported by meta conversion"
                        )

                    # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
                    # in a FakeTensor
                    def _to_fake_tensor(t: MetaTensorDesc):
                        # TODO: why aren't the recursive calls going to
                        # meta_tensor
                        if t.is_batchedtensor:
                            assert t.unwrapped is not None
                            assert t.level is not None
                            assert t.bdim is not None
                            ft = _to_fake_tensor(t.unwrapped)
                            lvl = t.level
                            bdim = t.bdim
                            # You cannot create functorch tensors without
                            # having the ambient functorch interpreter stack
                            # available, as the level refers to things in the
                            # stack
                            with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
                                t.functorch_stack
                            ):
                                r = _add_batch_dim(ft, bdim, lvl)
                        elif t.is_gradtrackingtensor:
                            assert t.unwrapped is not None
                            assert t.level is not None
                            disable_functorch = torch._C._DisableFuncTorch
                            with disable_functorch():
                                ft = _to_fake_tensor(t.unwrapped)
                            lvl = t.level
                            if lvl == GRAD_TENSOR_SENTINEL_VALUE:
                                r = ft
                            else:
                                with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
                                    t.functorch_stack
                                ):
                                    r = torch._C._functorch._wrap_for_grad(ft, lvl)

                            is_leaf = t.is_leaf
                            if t.requires_grad and safe_is_leaf(r):
                                r.requires_grad = True
                            elif t.requires_grad and not is_leaf:
                                r = torch._C._functions.DelayedError(  # type: ignore[assignment]
                                    "Internal error: Tried to backward() through example input",
                                    1,
                                )(
                                    r  # type: ignore[arg-type]
                                )
                        elif t.is_functional:
                            assert t.unwrapped is not None
                            assert t.current_level is not None
                            ft = self.meta_tensor(
                                t.unwrapped,
                                shape_env=shape_env,
                                callback=callback,
                                # NB: reuse these exactly, we treat the
                                # functional tensor as "invisible".
                                # TODO: Actually this all probably doesn't
                                # work, take a closer look.
                                source=source,
                                symbolic_context=symbolic_context,
                            )
                            r = _wrap_functional_tensor(ft, t.current_level)
                            # TODO: is_leaf/requires_grad?
                        else:
                            assert t.stride is not None

                            sizes = t.size
                            strides = t.stride
                            r = callback(
                                lambda: torch.empty_strided(
                                    sizes,
                                    strides,
                                    dtype=t.dtype,
                                    device="meta",
                                )
                            )
                            if self.copy_data:
                                with torch.no_grad(), no_dispatch():
                                    r.real_tensor = torch.empty_strided(  # type: ignore[attr-defined]
                                        t.size,
                                        t.stride,
                                        dtype=t.dtype,
                                        device=t.device,
                                    )
                                    assert t.data is not None
                                    _safe_copy(r.real_tensor, t.data)  # type: ignore[attr-defined]
                        return r

                    r = _to_fake_tensor(t)

                elif t.is_functional and t.device.type not in ["xla", "lazy"]:
                    assert t.unwrapped is not None
                    assert not t.is_functorch_wrapped  # handled above
                    unwrapped = self.meta_tensor(
                        t.unwrapped,
                        shape_env=shape_env,
                        callback=callback,
                        source=source,
                        symbolic_context=symbolic_context,
                    )
                    r = torch._to_functional_tensor(unwrapped)
                    torch._mirror_autograd_meta_to(t.autograd_meta_from, r)  # type: ignore[attr-defined]
  1162. elif t.is_view:
  1163. # Construct views in two steps: recursively meta-fy their
  1164. # base, and then create view(s) off that. NB: doing it
  1165. # directly from storage is WRONG because this won't cause
  1166. # version counters to get shared.
  1167. assert t.base is not None
  1168. base_symbolic_context = None
  1169. if shape_env and symbolic_context is not None:
  1170. from torch.fx.experimental.symbolic_shapes import (
  1171. StatelessSymbolicContext,
  1172. )
  1173. assert isinstance(symbolic_context, StatelessSymbolicContext)
  1174. # NB: This should generally be set when the input is a view,
  1175. # but the exception right now is for fake-ifying grads, which is
  1176. # a work in progress.
  1177. if symbolic_context.view_base_context is not None:
  1178. base_symbolic_context = symbolic_context.view_base_context
  1179. base = self.meta_tensor(
  1180. t.base,
  1181. shape_env,
  1182. callback,
  1183. source=torch._dynamo.source.AttrSource(source, "_base"),
  1184. symbolic_context=base_symbolic_context,
  1185. )
  1186. def is_c_of_r(complex_dtype, real_dtype):
  1187. return (
  1188. utils.is_complex_dtype(complex_dtype)
  1189. and utils.corresponding_real_dtype(complex_dtype)
  1190. == real_dtype
  1191. )
                    # In some situations, MetaConverter may be called in a
                    # context where autograd is disabled. For the _is_view
                    # assert to pass, we have to set up the autograd view
                    # metadata anyway. Do this by reenabling the
                    # ADInplaceOrView key. This is kind of a hack.
                    old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
                        torch._C.DispatchKey.ADInplaceOrView
                    )
                    torch._C._dispatch_tls_set_dispatch_key_excluded(
                        torch._C.DispatchKey.ADInplaceOrView, False
                    )
                    try:
                        if base.dtype == t.dtype:
                            pass
                        elif is_c_of_r(base.dtype, t.dtype):
                            base = torch.view_as_real(base)
                        elif is_c_of_r(t.dtype, base.dtype):
                            base = torch.view_as_complex(base)
                        else:
                            # This is not guaranteed to succeed. If it fails, it
                            # means there is another dtype-converting view function
                            # that hasn't been handled here
                            base = base.view(t.dtype)
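                            # (Sketch for intuition, comment only: dtype-reinterpreting views
                            # are things like
                            #
                            #   torch.zeros(4, dtype=torch.float32).view(torch.int32)
                            #
                            # which reinterpret the same storage under a same-itemsize dtype;
                            # view_as_real/view_as_complex above are the only other
                            # dtype-changing cases handled explicitly.)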

                        # This is very tricky. Naively, you might expect this
                        # to hold:
                        #
                        #   if t.requires_grad and not safe_is_leaf(t)
                        #       assert t._base.requires_grad
                        #
                        # But it's not true! As you can see in the following
                        # program:
                        #
                        #   x = torch.zeros(4)
                        #   y = x.view(1, 4)
                        #   y.requires_grad = True
                        #   z = y.view(1, 1, 4)
                        #   assert z._base is x
                        #
                        # So we may have to do *two* views out of the base to
                        # recreate this situation.
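                        # For example (sketch, mirroring the requires_grad-mismatch branch
                        # below): to rebuild z off a freshly fake-ified base, one can do
                        #
                        #   with torch.no_grad():
                        #       mid = base.view(base.shape)   # leaf view of base
                        #   mid.requires_grad = True
                        #   with torch.enable_grad():
                        #       z = mid.view(1, 1, 4)         # non-leaf view; base doesn't require grad
                        #
                        # which is exactly what the non-leaf else branch below does.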
                        if t.is_leaf:
                            # Leaf views that track view metadata are created by
                            # creating a view inside a no_grad block
                            with torch.no_grad():
                                r = view_from_base(base, t)
                            # As it's a leaf, we can directly assign requires_grad
                            r.requires_grad = t.requires_grad
                        else:
                            if t.base.requires_grad == t.requires_grad:
                                # Easy case, just run the view op
                                with torch.enable_grad():
                                    r = view_from_base(base, t)
                                # NB: We don't actually faithfully replicate
                                # autograd connectivity, but that doesn't matter
                                # today. See the following for more info:
                                # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
                            else:
                                # Obscure case. Create a leaf view and give it the
                                # correct requires_grad, then do the final view.
                                # NB: Can't have a non-leaf without requiring grad!
                                assert t.requires_grad
                                with torch.no_grad():
                                    mid = base.view(base.shape)
                                mid.requires_grad = t.requires_grad
                                with torch.enable_grad():
                                    r = view_from_base(mid, t)
                            # The CreationMeta influences whether or not inplace
                            # mutation is an error or not. So we need to make
                            # sure we properly propagate this as well.
                            assert t.creation_meta is not None
                            torch._C._autograd._set_creation_meta(r, t.creation_meta)
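                            # Rough illustration (comment only; behavior paraphrased):
                            # CreationMeta records *how* a view came to be, e.g.
                            #
                            #   base = torch.zeros(3, requires_grad=True).clone()
                            #   with torch.no_grad():
                            #       v = base.view(-1)
                            #   v.mul_(2)   # autograd consults v's CreationMeta to decide
                            #               # whether this in-place mutation is legal
                            #
                            # so the fake view must carry the same CreationMeta to reproduce
                            # the same accept/reject behavior as the real tensor.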
                    finally:
                        torch._C._dispatch_tls_set_dispatch_key_excluded(
                            torch._C.DispatchKey.ADInplaceOrView, old_exclude
                        )
                else:
                    is_leaf = t.is_leaf

                    # Graph-Break for wrapped tensors
                    if (
                        not (t.is_batchedtensor or t.is_gradtrackingtensor)
                        and t.is_functorch_wrapped
                    ) or t.is_legacy_batchedtensor:
                        return NotImplemented

                    (
                        sizes,
                        strides,
                        storage_offset,
                    ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)

                    # If we have a subclass that desugars into dense tensors,
                    # perform our callback on each inner tensor.
                    if t.is_traceable_wrapper_subclass:
                        r = empty_create_subclass(
                            t, outer_size=sizes, outer_stride=strides
                        )
                    else:
                        r = callback(
                            lambda: torch.empty_strided(
                                sizes,
                                strides,
                                dtype=t.dtype,
                                device="meta",
                            )
                        )
                        if self.copy_data:
                            with torch.no_grad(), no_dispatch():
                                assert t.size is not None
                                assert t.stride is not None
                                r.real_tensor = torch.empty_strided(
                                    t.size, t.stride, dtype=t.dtype, device=t.device
                                )
                                _safe_copy(r.real_tensor, t.data)

                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    if t.requires_grad:
                        r.requires_grad = t.requires_grad
                        if not is_leaf:
                            # Fake up some autograd history.
                            # Note: we *used* to call .clone() here to mock up some autograd history.
                            # This is bad for subclasses.
                            # Consider the case where you have a wrapper subclass that is contiguous,
                            # but its inner tensor is noncontiguous.
                            # .clone() (or other ops) will have the side effect of changing
                            # the metadata of the inner tensor.
                            # So instead, we now have a dedicated fn to set autograd history,
                            # without inadvertently changing other metadata.
                            r = torch._C._functions.DelayedError(
                                "Internal error: Tried to backward() through example input",
                                1,
                            )(r)
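                            # Sketch of the effect (illustration only, comment): DelayedError
                            # attaches an error-raising grad_fn to r without cloning it, e.g.
                            #
                            #   x = torch.randn(3, requires_grad=True)
                            #   y = torch._C._functions.DelayedError("boom", 1)(x)
                            #   y.grad_fn is not None    # True: fake history, no metadata change
                            #   y.sum().backward()       # raises, surfacing the "boom" message
                            #
                            # which is all we need: r looks like a non-leaf, but backward()
                            # through it is an error by construction.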

                    s = t.storage
                    assert s is not None
                    if s.id not in self.storage_memo and (
                        r.is_nested
                        or (
                            r.stride() == strides
                            and r.storage_offset() == storage_offset
                        )
                    ):
                        # You're normal and happy, install the fresh storage into the memo
                        self.set_storage_memo(s, r.untyped_storage())
                        if self.copy_data:
                            r.untyped_storage().real_storage = (
                                r.real_tensor.untyped_storage()
                            )
                    else:
                        # You're in crazy town; somehow you gave us a tensor
                        # that wasn't a view, but had nonzero storage offset,
                        # nontrivial strides (such that clone() couldn't
                        # preserve them), or already aliases with another
                        # tensor's storage. The most typical way to end
                        # up here is with set_. So use set_ to bludgeon this
                        # in.
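                        # One way to land here (sketch, for intuition only):
                        #
                        #   src = torch.randn(10)
                        #   alias = torch.tensor([])
                        #   alias.set_(src.untyped_storage(), 2, (4,), (2,))
                        #
                        # alias is not a view (alias._base is None), yet it has a nonzero
                        # storage_offset() and shares src's storage, so a fresh
                        # empty_strided cannot reproduce it; we replay a set_ ourselves.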
                        r_s = self.meta_storage(s, callback=callback)
                        # NB: In principle, this should always work, but there
                        # is some subtle difference in the autograd metadata
                        # that means we will backprop the set_ call, even if
                        # r is declared as an input to grad.
                        # See https://github.com/pytorch/pytorch/issues/87956
                        # for the reproducer.
                        # NB: The in_kernel_invocation_manager here is necessary
                        # for fake tensor. If we run the set_ call with fake
                        # tensor on, r will improperly report that it is NOT a
                        # meta tensor but a cpu tensor, and then the set_ call
                        # will fail due to device mismatch. no_dispatch() is
                        # not enough, because the fake tensor will still claim
                        # to be a CPU tensor and you'll end up in the CPU
                        # kernel. Arguably this is a hack; a cleaner way to
                        # solve this is to have a FakeStorage concept which
                        # would report its CPU device--no problem now! But
                        # this is difficult to do because we don't have storage
                        # subclasses. Relevant test is
                        # DynamicShapesFunctionTests::test_add_dynamic_shapes in
                        # test/dynamo/test_dynamic_shapes.py
                        maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
                        from torch._subclasses.fake_tensor import (
                            in_kernel_invocation_manager,
                            maybe_get_fake_mode,
                        )

                        mb_fake_mode = maybe_get_fake_mode(r)
                        if mb_fake_mode is not None:
                            maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
                        with torch.no_grad(), maybe_suppress():
                            with maybe_fake_mgr:
                                r.set_(r_s, storage_offset, sizes, strides)
                        if self.copy_data:
                            with torch.no_grad(), no_dispatch():
                                r.real_tensor.set_(
                                    r_s.real_storage,
                                    t.storage_offset,
                                    t.size,
                                    t.stride,
                                )

                if t.grad is not None:
                    from torch._dynamo.source import AttrSource

                    # TODO: Use a valid grad-specific symbolic context instead of recycling
                    # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view().
                    r.grad = self.meta_tensor(
                        t.grad,
                        shape_env,
                        callback,
                        source=AttrSource(source, "grad"),
                        symbolic_context=symbolic_context,
                    )
                torch._C._set_conj(r, t.is_conj)
                torch._C._set_neg(r, t.is_neg)
                # This can be skipped if necessary for performance reasons
                skip_leaf = (
                    t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
                )
                assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
                # Thanks to storage resizing, it's possible to end up with a tensor
                # that advertises a real size, but has a storage that actually has zero bytes.
                # Need to reflect this in the generated FakeTensor.
                if t.storage is not None and t.storage.size == 0:
                    r.untyped_storage().resize_(0)
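                # For intuition (sketch, comment only): this mirrors real tensors whose
                # storage has been released out from under them, e.g.
                #
                #   x = torch.randn(4)
                #   x.untyped_storage().resize_(0)
                #   x.shape                          # still torch.Size([4])
                #   x.untyped_storage().nbytes()     # now 0
                #
                # so the fake tensor must advertise the size while holding an empty storage.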

                if t.is_parameter:
                    r._is_param = True

                self.set_tensor_memo(t, r)

        return self.get_tensor_memo(t)

    def __call__(
        self,
        t,
        shape_env=None,
        *,
        callback=lambda t: t(),
        source=None,
        symbolic_context=None,
        # Controls whether or not we should dump the tensor metadata to structured logs
        # when source is not None. Because we refakify after Dynamo is done,
        # we don't want to dump this info again from AOTAutograd; it would be redundant.
        trace=True,
    ):
        # TODO: zero tensors? We appear to have eliminated them by
        # excluding complex for now

        # Filter out cases we don't support
        # TODO: This can probably be simplified quite a bit
        if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t):
            if (
                # Lazy tensors are not supported. Note that XLA is
                # implemented on top of lazy tensor but is not excluded here;
                # we have some special handling for it; this is for XLA Dynamo
                # integration
                t.device.type == "lazy"
                or
                # Quantization is not supported
                t.is_quantized
                or
                # Views out of sparse tensors not currently supported (plain
                # sparse is supported though)
                (t._is_view() and t._base is not None and t._base.is_sparse)
            ):
                self.miss += 1
                return NotImplemented
            else:
                self.hit += 1
        elif torch.overrides.is_tensor_like(t):
            self.miss += 1
            return NotImplemented
        else:
            # non-Tensor types don't count as hit or miss
            return t

        if source is None:
            trace = False

        # Describe the tensor. NB: do NOT disable ambient modes, we may need
        # to query them when figuring out what to put in here
        t_desc = self.describer.describe_tensor(t, trace=trace)

        if trace:
            trace_structured(
                "describe_source",
                metadata_fn=lambda: {
                    "describer_id": self.describer.id,
                    "id": t_desc.id,
                    "source": source.name(),
                },
            )

        # Do the meta-fication. Here, we disable all the ambient modes, to
        # better simulate what it would be like to re-fakeify from a fresh
        # process
        with contextlib.ExitStack() as exit_stack:
            exit_stack.enter_context(torch._dispatch.python.suspend_functionalization())
            st = peek_interpreter_stack()
            if st is not None:
                exit_stack.enter_context(
                    torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack()
                )

            r = self.meta_tensor(
                t_desc,
                shape_env=shape_env,
                callback=callback,
                source=source,
                symbolic_context=symbolic_context,
            )

        if type(t) is torch.nn.Parameter:
            # NB: Cannot directly use Parameter constructor
            # because that would force a detach, not desirable
            r._is_param = True

        # TODO: return the description for later
        return r


import torch._prims_common as utils
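
# Usage sketch (illustrative comment only; not part of this module's API surface):
# MetaConverter is normally driven by the fake tensor machinery, but the basic
# call pattern looks like
#
#   converter = MetaConverter()   # depending on the version, may also accept copy_data=...
#   x = torch.randn(4, 8, requires_grad=True)
#   m = converter(x)              # meta tensor mirroring x's metadata
#   assert m.device.type == "meta"
#   assert m.shape == x.shape and m.requires_grad == x.requires_grad
#
# NotImplemented is returned for inputs the converter declines (lazy tensors,
# quantized tensors, views of sparse tensors), as filtered in __call__ above.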