fake_tensor.py

# mypy: allow-untyped-defs
import contextlib
import functools
import logging
import os
import traceback
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    TypeVar,
    Union,
)
from weakref import ReferenceType

import torch
import torch._custom_op
import torch._logging
from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor
from torch._guards import Source
from torch._ops import OpOverload
from torch._prims_common import suggest_memory_format
from torch._subclasses.meta_utils import (
    assert_eq,
    assert_metadata_eq,
    is_sparse_any,
    is_sparse_compressed,
    MetaConverter,
)
from torch._utils import render_call
from torch.fx.operator_schemas import normalize_function
from torch.multiprocessing.reductions import StorageWeakRef
from torch.overrides import TorchFunctionMode
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import (
    is_traceable_wrapper_subclass,
    TorchDispatchMode,
)
from torch.utils._pytree import PyTree, tree_map, tree_map_
from torch.utils._stats import count
from torch.utils._traceback import CapturedTraceback

if TYPE_CHECKING:
    from torch.fx.experimental.symbolic_shapes import ShapeEnv
    from torch.types import _bool


class _Unassigned:
    pass


def _is_plain_tensor(t):
    return (
        type(t) is torch.Tensor
        and t.layout == torch.strided
        and not (
            t.is_sparse
            or t.is_nested
            or is_functorch_wrapped_tensor(t)
            or is_legacy_batchedtensor(t)
            or torch._is_functional_tensor(t)
        )
    )


_UNASSIGNED = _Unassigned()

DimList = List

log = logging.getLogger(__name__)

# TODO: Hack to unblock https://github.com/pytorch/pytorch/pull/108186
# Proper fix tracked by https://github.com/pytorch/pytorch/issues/120105
try:
    not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
except ValueError as e:
    if "'not_implemented' not registered" in str(e):
        import logging as not_implemented_log
    else:
        raise e

pytree = torch.utils._pytree
T = TypeVar("T")
TensorWeakRef = Any

aten = torch._ops.ops.aten

CONSTANT_NUMEL_LIMIT = 1

RECURSION_COUNT = 0


# Small helper that increments the recursion count, and
# decrements it when the object goes out of scope. Useful
# if you don't want to increase indentation, which is
# what a context manager would do.
class IncrementRecursionCount:
    def __init__(self):
        global RECURSION_COUNT
        RECURSION_COUNT += 1

    def __del__(self):
        global RECURSION_COUNT
        RECURSION_COUNT -= 1


@dataclass
class UnsupportedFakeTensorException(RuntimeError):
    reason: str


@dataclass
class DynamicOutputShapeException(RuntimeError):
    func: OpOverload


@dataclass
class DataDependentOutputException(RuntimeError):
    func: OpOverload


@dataclass
class UnsupportedOperatorException(RuntimeError):
    func: OpOverload


def ordered_set(*items):
    return dict.fromkeys(items, True)


@contextlib.contextmanager
def unset_fake_temporarily():
    old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
    try:
        yield old
    finally:
        if old is not None:
            torch._C._set_dispatch_mode(old)
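

# A minimal usage sketch (illustrative, assuming a default FakeTensorMode):
# `unset_fake_temporarily` pauses an active fake mode so real tensors can be
# created in the middle of faked execution.
#
#     with FakeTensorMode():
#         fake = torch.empty(2, 2)           # FakeTensor
#         with unset_fake_temporarily():
#             real = torch.ones(2, 2)        # real tensor; fake mode is paused
#         assert is_fake(fake) and not is_fake(real)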


def is_fake(x):
    if isinstance(x, FakeTensor):
        return True
    if is_traceable_wrapper_subclass(x):
        attrs, _ = type(x).__tensor_flatten__(x)
        flattened_tensors = [getattr(x, attr) for attr in attrs]
        # need to recurse because we could have nested subclasses
        all_fake = all(is_fake(x) for x in flattened_tensors)
        any_fake = any(is_fake(x) for x in flattened_tensors)
        assert all_fake == any_fake, "got mixed fake and real tensors!"
        return all_fake
    elif isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
        reapply_views = torch._C._functionalization_reapply_views_tls()
        unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views)
        return is_fake(unwrapped)
    elif isinstance(x, torch.Tensor) and is_functorch_wrapped_tensor(x):
        unwrapped = torch._C._functorch.get_unwrapped(x)
        return is_fake(unwrapped)
    return False


def maybe_get_fake_mode(t):
    if isinstance(t, FakeTensor):
        return t.fake_mode
    if is_traceable_wrapper_subclass(t):
        inner_tensor_names, _ = t.__tensor_flatten__()
        modes = [
            maybe_get_fake_mode(getattr(t, t_name)) for t_name in inner_tensor_names
        ]
        m = modes[0]
        assert all(m is x for x in modes)
        return m
    elif isinstance(t, torch.Tensor) and torch._is_functional_tensor(t):
        reapply_views = torch._C._functionalization_reapply_views_tls()
        unwrapped = torch._C._functorch._unwrap_functional_tensor(t, reapply_views)
        return maybe_get_fake_mode(unwrapped)
    elif isinstance(t, torch.Tensor) and is_functorch_wrapped_tensor(t):
        unwrapped = torch._C._functorch.get_unwrapped(t)
        return maybe_get_fake_mode(unwrapped)
    return None
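

# A minimal sketch of how `is_fake` and `maybe_get_fake_mode` are typically
# used (illustrative; assumes FakeTensorMode.from_tensor as the conversion
# entry point):
#
#     mode = FakeTensorMode()
#     fake = mode.from_tensor(torch.randn(3))
#     assert is_fake(fake)
#     assert maybe_get_fake_mode(fake) is mode
#     assert maybe_get_fake_mode(torch.randn(3)) is None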


@functools.lru_cache(None)
def get_schema_info(func):
    return torch._C._SchemaInfo(func._schema)  # type: ignore[attr-defined]


# many of the decompositions registered to torch/_prims do not at the moment model
# aliasing or strides, so as an incremental step, just enable the decompositions in
# torch/_decomp/decompositions.py.
# decomps are used for aot autograd tracing so we would like to unify on their
# implementation and add additional testing to them
@functools.lru_cache(None)
def torch_decomp_decompositions(func):
    from torch._decomp import decomposition_table

    decompositions = torch._decomp.decompositions
    # Note that the function in the decomposition table might be
    # different from the one in the module because of the difference
    # in out handling in aten API and torch public API
    return decomposition_table[func].__module__.startswith(
        "torch._decomp"
    ) and decomposition_table[func].__name__ in dir(decompositions)


def tree_flatten_only(ty: Type[T], tree: PyTree):
    flat_vals = pytree.tree_leaves(tree)
    return [elem for elem in flat_vals if isinstance(elem, ty)]


# Similar to `MetaConverter`, this is a class for converting
# multiple tensors into fake tensors which share the same view/storage
# structure. Like `MetaConverter`, it uses `WeakIdRef` to
# hold a weak reference for all memoized tensors.
class FakeTensorConverter:
    @property
    def tensor_memo(self):
        return self.meta_converter.tensor_memo

    meta_converter: MetaConverter
    constant_storage_mapping: Dict[StorageWeakRef, List[ReferenceType]]
    export: bool

    def __init__(self, *, copy_data=False, export=False):
        self.meta_converter = MetaConverter(copy_data=copy_data)
        self.export = export

        # map from storage to corresponding constant tensors
        self.constant_storage_mapping = {}

    def add_constant_storage_mapping(self, fake_tensor):
        # when you have a constant, aliased tensor:
        # const_tensor.add_(torch.rand([1]))
        # all aliases of it must become no longer const
        assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None
        weak_st = StorageWeakRef(fake_tensor.constant._typed_storage())

        # we need a map from a weak storage to all of its corresponding
        # constant tensors. python doesn't have the weak value equivalent
        # of defaultdict(list), so we are using a WeakValueDictionary as one
        if weak_st not in self.constant_storage_mapping:
            self.constant_storage_mapping[weak_st] = []
        self.constant_storage_mapping[weak_st].append(weakref.ref(fake_tensor))

    def invalidate_constant_aliases(self, tensor):
        assert not isinstance(tensor, FakeTensor)

        weak_st = StorageWeakRef(tensor._typed_storage())
        if weak_st not in self.constant_storage_mapping:
            return

        for weak_tensor_ref in self.constant_storage_mapping[weak_st]:
            ten = weak_tensor_ref()
            if ten is not None:
                ten._fix_weakref()
                ten.constant = None

        del self.constant_storage_mapping[weak_st]

    def _get_memo(self, t):
        tid = self.meta_converter.describer.lookup_tensor.get(t)
        if tid is None:
            return None
        return self.tensor_memo.get(tid)

    def set_tensor_memo(self, t, v):
        tid = self.meta_converter.describer.get_tensor_id(t)
        self.meta_converter.tensor_memo[tid] = v

    # You can have a real tensor that you need to convert into a fake tensor.
    # If you have a meta tensor already, call from_meta_and_device.
    #
    # You're allowed to pass a meta tensor to be turned into a fake
    # tensor; although an odd thing to do, this can occur if you're doing
    # cross ref testing and the inner test is already operating on meta tensors.
    def from_real_tensor(
        self,
        fake_mode,
        t,
        make_constant=False,
        shape_env=None,
        *,
        source=None,
        symbolic_context=None,
        trace=True,
    ):
        # see note [Tensor Fakification and Symbol Caching]
        if not symbolic_context and not source and shape_env:
            if tracing_context := torch._guards.TracingContext.try_get():
                if t in tracing_context.tensor_to_context:
                    symbolic_context = tracing_context.tensor_to_context[t]
                    source = symbolic_context.tensor_source

        maybe_memo = self._get_memo(t)
        if maybe_memo is not None:
            return maybe_memo
        existing_device = t.device
        # not yet supported in metatensors
        if t.is_quantized:
            raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
        if type(t) is torch.nn.Parameter:
            assert not make_constant

        def mk_fake_tensor(make_meta_t):
            # NB: don't use in_kernel_invocation_manager here, to
            # ensure FakeTensor can internally do constant computation
            # as necessary. Invocation manager is "more correct" as
            # it works for more operators in make_meta_t, but
            # invariant is that make_meta_t only calls factories
            # for which it is not strictly necessary to use the
            # invocation manager (I think!)
            with no_dispatch():
                return FakeTensor(
                    fake_mode,
                    make_meta_t(),
                    existing_device,
                    # TODO: callback might be used in recursive contexts, in
                    # which case using t is wrong! BUG!
                    constant=t if make_constant else None,
                )

        out = self.meta_converter(
            t,
            shape_env=shape_env,
            callback=mk_fake_tensor,
            source=source,
            symbolic_context=symbolic_context,
            trace=trace,
        )
        if out is NotImplemented:
            raise UnsupportedFakeTensorException("meta converter nyi")

        from torch._dynamo.source import RandomValueSource

        value = None
        if (
            not self.export
            and _is_plain_tensor(t)  # mostly, we want to know if item() works
            and t.dim() == 0
            and t.device.type == "cpu"
            # All integer types are fair game, because signed overflow is UB
            # (and even int64 can overflow, since integers in Python are
            # arbitrary precision). But only float64 is OK for float, because
            # switching between float32 and float64 changes semantics in an
            # observable way without hitting UB.
            and t.dtype
            in [torch.int64, torch.int32, torch.int16, torch.int8, torch.float64]
            and source is not None
            # Impede setting up item() on things coming from random. These
            # are not "real" item() calls, instead UnspecializedPythonVariable
            # is unsafely pretending an int is a tensor, which can sometimes
            # implicitly cause an item call. The problem is this is pretty
            # unsound: there's no reason substituting an int with a Tensor is
            # going to give the same results. Today, you mostly get around
            # this by typically not having capture_scalar_outputs on and graph
            # breaking when someone tries to use the unspec variable in an
            # int-y context. But allowing it through here would break that.
            # So don't.
            #
            # Once random values are setup to be represented as
            # SymNodeVariable, this condition can be removed. To check if
            # you've done it right, this is a good test:
            #
            #   PYTORCH_TEST_WITH_DYNAMO=1 python test/test_reductions.py -k
            #   TestReductionsCPU.test_dim_reduction_fns_fn_name_amax_cpu_bfloat16
            and not isinstance(source, RandomValueSource)
            # In Dynamo, shape_env is never None (even with static shapes).
            # However, FakeTensorMode can be used by hand and in some cases
            # ShapeEnv is not allocated.
            and shape_env is not None
        ):
            from torch._dynamo.source import CallMethodItemSource, FloatTensorSource
            from torch.fx.experimental.symbolic_shapes import DimDynamic

            with no_dispatch():
                value = t.item()
            # Peephole strip out unnecessary torch.as_tensor(x).item()
            if isinstance(source, FloatTensorSource):
                item_source = source.base
            else:
                item_source = CallMethodItemSource(source)
            symbol = shape_env.create_unspecified_symbol(
                value,
                source=item_source,
                dynamic_dim=DimDynamic.DYNAMIC,
            )
            # NB: reusing item_memo here ensures that we invalidate on
            # mutation
            if t.dtype == torch.int64:
                out.item_memo = shape_env.create_symintnode(
                    symbol,
                    hint=value,
                    source=item_source,
                )
            elif t.dtype == torch.float64:
                out.item_memo = shape_env.create_symfloatnode(
                    symbol,
                    hint=value,
                    source=item_source,
                )
        if make_constant:
            self.add_constant_storage_mapping(out)
        # NB: meta_converter set the memo
        return out

    # If you specify the device, it MUST be a meta tensor.
    def from_meta_and_device(self, fake_mode, t, device):
        assert (
            t.device.type == "meta"
        ), f"tensor's device must be `meta`, got {t.device.type} instead"
        # This is a bit abusive (this is not the "real" tensor) but whatever,
        # the meta tensor should be fresh so there's no way to get it wrong
        maybe_memo = self._get_memo(t)
        if maybe_memo is not None:
            return maybe_memo
        out = FakeTensor(fake_mode, t, device)
        self.set_tensor_memo(t, out)
        return out
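

# A minimal sketch of converter behavior (illustrative): conversions preserve
# the reported ("fake") device and are memoized per real tensor, assuming the
# converter hanging off a FakeTensorMode instance.
#
#     mode = FakeTensorMode()
#     real = torch.randn(4)
#     fake = mode.fake_tensor_converter.from_real_tensor(mode, real)
#     assert fake.device == real.device      # reports the original device
#     again = mode.fake_tensor_converter.from_real_tensor(mode, real)
#     assert again is fake                   # memoized via the meta converter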


@functools.lru_cache(None)
def init_cuda_context():
    # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first
    if torch.cuda.is_available():
        torch.empty(1, device="cuda") if torch.version.hip is None else torch.zeros(
            1, device="cuda"
        )


@contextlib.contextmanager
def in_kernel_invocation_manager(fake_mode):
    # See: note [Fake Tensor Dispatch Keys]
    prev_in_kernel = fake_mode.in_kernel_invocation
    meta_in_tls = torch._C._meta_in_tls_dispatch_include()
    assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}"

    with torch._C._DisableTorchDispatch():
        fake_mode.in_kernel_invocation = True
        # Unfortunately _set_meta_in_tls_dispatch_include(False) can leave
        # `Dense` turned on (because it's implied by `Meta`)
        with torch._C._PreserveDispatchKeyGuard():
            torch._C._set_meta_in_tls_dispatch_include(True)
            try:
                yield
            finally:
                fake_mode.in_kernel_invocation = prev_in_kernel
                # torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel)


# Return if the function allows Python numbers to bind to Tensors
def should_allow_numbers_as_tensors(func: OpOverload):
    return torch._C._should_allow_numbers_as_tensors(
        func.name().split("::")[-1].split(".")[0]
    )


class FakeTensorConfig:
    debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", "0") == "1"


# This memoizes the unbacked SymInt representing quantities like the number
# of nonzero elements in this tensor. There is one instance of the descriptor
# per particular quantity to memoize.
#
# Memoization is helpful if you do something like x[mask] and y[mask];
# mask.nonzero() gets repeatedly called and should give a consistent unbacked
# SymInt. It needs to be invalidated in the same way constant is.
#
# Making this a descriptor may seem overly fancy, but actually it's the most
# convenient way to make sure we have access to FakeTensor during access,
# which is required for testing version counter and epoch validity
class UnbackedMemoDescriptor:
    _name: str

    def __set_name__(self, owner, name):
        self._name = name

    def _memo(self, obj):
        return f"_{self._name}"

    def _memo_vc(self, obj):
        return f"_{self._name}_vc"

    # When we retrace, we need to invalidate all the memos so that we can
    # accurately identify the first time unbacked SymInts are allocated.
    # This is only relevant for inputs; for intermediates, they will get fresh
    # fake tensors so you won't have a memo anyway
    def _memo_epoch(self, obj):
        return f"_{self._name}_epoch"

    def __get__(self, obj: "FakeTensor", objtype=None):
        if (r := getattr(obj, self._memo(obj))) is None:
            return None
        # Version counter based tracking isn't 100% sound but it's close
        # enough
        if (
            getattr(obj, self._memo_vc(obj)) != obj._version
            or getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch
        ):
            setattr(obj, self._memo(obj), None)
            return None
        return r

    def __set__(self, obj, value):
        if value is None:
            setattr(obj, self._memo(obj), None)
            setattr(obj, self._memo_vc(obj), None)
            setattr(obj, self._memo_epoch(obj), None)
        elif not torch.is_inference_mode_enabled():
            setattr(obj, self._memo(obj), value)
            setattr(obj, self._memo_vc(obj), obj._version)
            setattr(obj, self._memo_epoch(obj), obj.fake_mode.epoch)
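

# A rough sketch of the memo protocol this descriptor implements (illustrative;
# `fake` stands for any FakeTensor and `u0` for an unbacked SymInt):
#
#     fake.nonzero_memo = u0      # records u0 plus fake._version and the mode epoch
#     fake.nonzero_memo           # -> u0, while the version counter and epoch match
#     fake.add_(1)                # mutation bumps _version
#     fake.nonzero_memo           # -> None; the stale memo is dropped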


class FakeTensor(torch.Tensor):
    """
    Meta tensors give you the ability to run PyTorch code without having to
    actually do computation through tensors allocated on a `meta` device.
    Because the device is `meta`, meta tensors do not model device propagation.
    FakeTensor extends MetaTensors to also carry an additional `fake_device`
    which tracks devices that would have been used.
    """

    fake_device: torch.device
    fake_mode: "FakeTensorMode"
    constant: Optional[torch.Tensor]
    real_tensor: Optional[torch.Tensor]

    # TODO: Generalize this as needed, e.g., into a trie of memos, if
    # you do something like x[0].item()  (x[0] is fresh each time, so
    # memo mechanism here won't work)
    nonzero_memo = UnbackedMemoDescriptor()
    item_memo = UnbackedMemoDescriptor()
    unique_memo = UnbackedMemoDescriptor()

    # Indicates to our torch_dispatch dispatching infra that
    # this is an "infra" mode with lower dispatching precedence.
    _mode_key = torch._C._TorchDispatchModeKey.FAKE

    @property
    def device(self):
        if self.fake_mode.in_kernel_invocation:
            return torch.device("meta")
        else:
            return self.fake_device

    # Note: [Fake Tensor Dispatch Keys]
    # In order to model the behavior of device-specific autocast
    # and autograd logic, we update the dispatch keys of FakeTensors
    # to reflect their fake device. This includes the BackendComponent
    # (DispatchKey::Meta -> DispatchKey::CUDA), and also the BackendComponent
    # related Autocast and Autograd keys. __torch_dispatch__ sits below
    # Autocast and Autograd, and is only invoked when we are at the
    # kernel for the BackendComponent. Then, we add Meta to the
    # thread-local dispatch include set to hit the meta kernel
    # instead of the kernel of the BackendComponent for the fake device.
    # The `device_for_backend_keys` does that below
    # NOTE: this probably will not do the right thing for backends
    # that have dispatch keys which are higher than the "meta" key:
    # https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L189

    # We don't support named tensors; graph break
    @property
    def names(self):
        raise UnsupportedFakeTensorException(
            "torch.compile doesn't support named tensors"
        )

    @staticmethod
    def __new__(cls, fake_mode, elem, device, constant=None, real_tensor=None):
        self = torch.Tensor._make_subclass(
            cls,
            elem,
            elem.requires_grad,
            dispatch_device=True,
            device_for_backend_keys=device,
        )
        if not fake_mode._allow_unsafe_data_ptr_access:
            torch._C._set_throw_on_mutable_data_ptr(self)
        else:
            torch._C._set_warn_deprecated_on_mutable_data_ptr(self)

        assert elem.device.type == "meta", elem.device.type
        device = device if isinstance(device, torch.device) else torch.device(device)
        # NB: it is fine, if a little confusing, for device to be meta
        # (we are faking a meta tensor in that case). However, it often
        # indicates some sort of confusion (e.g., you accidentally passed
        # in a meta tensor when you should have passed in the real tensor).
        # So by default we disallow meta, and if you are working in a situation
        # where it is helpful (e.g., crossref testing) you can turn it back
        # on
        if not fake_mode.allow_meta:
            assert device.type != "meta"
        # normalize device.
        if device.type == "cuda":
            init_cuda_context()

        if (
            device.type
            in ["cuda", "hpu", "xpu", torch._C._get_privateuse1_backend_name()]
            and device.index is None
        ):
            if getattr(torch, device.type).is_initialized():
                device = torch.device(
                    f"{device.type}:{getattr(torch, device.type).current_device()}"
                )
            else:
                device = torch.device(f"{device.type}:0")
        self.fake_device = device  # type: ignore[attr-defined]
        self.fake_mode = fake_mode  # type: ignore[attr-defined]
        self.constant = constant  # type: ignore[attr-defined]
        assert not isinstance(real_tensor, FakeTensor)
        self.real_tensor = real_tensor  # type: ignore[attr-defined]
        self.nonzero_memo = None
        self.item_memo = None
        self.unique_memo = None
        if FakeTensorConfig.debug:
            self._debug_trace = CapturedTraceback.extract()  # type: ignore[attr-defined]
        return self

    # In some circumstances, a conventional torch.Tensor constructor
    # will get rewritten to call into FakeTensor. We must provide an
    # __init__ method that can accept the Python interpreter's initialization
    # in such a situation; we must also be able to handle direct fake
    # tensor construction via FakeTensor().
    #
    # In particular, the __init__ call will look funny in the following case:
    #
    #   with FakeTensorMode():
    #       x = torch.Tensor([1, 2, 3])
    #
    # this desugars into:
    #
    #   with FakeTensorMode():
    #       x = torch.Tensor.__new__([1, 2, 3])
    #       # NB: x is a fake tensor, because of the mode!
    #       x.__init__([1, 2, 3])  # not the normal fake tensor args!
    #
    def __init__(self, *args, **kwargs):
        super().__init__()

    @staticmethod
    def from_tensor(t, fake_mode):
        return fake_mode.from_tensor(t)

    @classmethod
    @count
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # need to handle here to avoid infinite recursion
        # see [in_kernel_invocation]
        if func == torch.ops.prim.device.default:
            assert len(args) == 1 and isinstance(args[0], FakeTensor)
            if args[0].fake_mode.in_kernel_invocation:
                return torch.device("meta")
            else:
                return args[0].fake_device

        # this handler must be done inside FakeTensor subclass, not mode, because
        # we can end up dispatching here when we have a fake tensor with
        # symbolic sizes running under in_kernel_invocation_manager.
        # The subclass is asked to handle this query because size (not
        # sym_size) was called, but we are unable to serve it directly because
        # there are symbolic sizes in the class. The use of
        # in_kernel_invocation_manager means it's incorrect to activate a
        # mode to actually handle this (this caused
        # https://github.com/pytorch/pytorch/issues/122772).
        if handler := _DISPATCH_META_HANDLERS.get(func):
            return handler(args)

        # Because fake mode can return NotImplemented (if it sees a subclass
        # it doesn't know how to deal with), this test here is important
        # because the next dispatch after a fake mode will attempt to use
        # subclasses of tensors to dispatch, and any FakeTensor arguments
        # will be considered eligible.
        unrecognized_types = [
            t for t in types if not issubclass(t, FakeTensor) and t is not torch.Tensor
        ]
        if unrecognized_types:
            not_implemented_log.debug(
                "FakeTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        fake_mode = None
        for arg in pytree.arg_tree_leaves(*args, **kwargs):
            if isinstance(arg, FakeTensor):
                fake_mode = arg.fake_mode
                break

        assert fake_mode is not None

        # If the fake mode is already active, don't try to reapply it!
        # NotImplemented is the right thing to return here, because the
        # typical situation this can occur is if ProxyTensorMode returned a
        # NotImplemented because of a not implemented subclass; we may have
        # unluckily attempted to hit FakeTensor's dispatch first,
        # NotImplemented lets us keep chaining until we find the actual
        # subclass
        maybe_cur_fake_mode = torch._C._get_dispatch_mode(
            torch._C._TorchDispatchModeKey.FAKE
        )
        if maybe_cur_fake_mode:
            not_implemented_log.debug(
                "FakeTensor mode already active: %s in %s",
                fake_mode,
                maybe_cur_fake_mode,
            )
            return NotImplemented

        assert not fake_mode.in_kernel_invocation

        with fake_mode:  # type: ignore[attr-defined]
            return func(*args, **kwargs)

    @staticmethod
    def _find_common_device(func, flat_args) -> Tuple[torch.device, bool]:
        # Returns: (common_device, has_scalar_only_inputs)

        # cpu zero-dim tensors can be called in cuda kernels, so overwrite
        # the common_device if the only existing device comes from a cpu
        # zero-dim tensor
        common_device = None
        has_scalar_only_inputs = False
        is_cpu_zero_dim = None

        def cpu_zero_dim(t):
            return t.device.type == "cpu" and t.dim() == 0

        def merge_devices(t):
            nonlocal common_device
            nonlocal is_cpu_zero_dim
            if not isinstance(t, FakeTensor):
                return

            if common_device is None:
                common_device = t.device
                is_cpu_zero_dim = cpu_zero_dim(t)
                return

            t_is_cpu_zero_dim = cpu_zero_dim(t)
            if t.device == common_device:
                if is_cpu_zero_dim:
                    is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # mismatching devices!
            # if current tensor is cpu 0 dim, defer to existing device
            if t_is_cpu_zero_dim:
                return

            # current device is from cpu 0 dim tensor, overwrite
            if is_cpu_zero_dim:
                common_device = t.device
                is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # mismatching devices of non-zero dim tensors, throw
            # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
            raise RuntimeError(
                f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
            )

        for arg in flat_args:
            merge_devices(arg)

        # some functions allow Python numbers to bind to Tensors; if we have
        # failed to find a device, and we're running one of these operators,
        # we must have scalar only inputs
        if should_allow_numbers_as_tensors(func) and common_device is None:
            # ops with scalar only inputs always have result on cpu
            has_scalar_only_inputs = True
            common_device = torch.device("cpu")

        assert common_device is not None, f"Could not find common device for {func}"

        return common_device, has_scalar_only_inputs
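
    # Illustrative consequence of the rules above (a sketch, not exhaustive):
    # a cuda FakeTensor combined with a cpu zero-dim tensor resolves to cuda,
    # mirroring eager's acceptance of cpu scalars in cuda kernels, e.g.
    #
    #     with FakeTensorMode():
    #         gpu = torch.empty(3, device="cuda")   # assuming cuda is usable here
    #         out = gpu * torch.tensor(2.0)         # cpu, zero-dim operand
    #         assert out.device.type == "cuda"
    #
    # whereas two non-zero-dim tensors on different devices raise RuntimeError.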

    # We must handle tolist in a special way for FakeTensors here in the case
    # where tolist is called from torch dispatch for tensor subclasses.
    # Ordinarily, if a program calls .tolist, compiling still works because there is
    # special handling in dynamo, but for tensor subclasses if .tolist is called
    # inside torch dispatch, the .tolist call may be directly on a FakeTensor.
    # This would result in an error since wrapper subclasses don't have storage.
    # To avoid this, we handle the FakeTensor case by (1) specializing on the size
    # of the tensor to create the output Python list, and (2) creating unbacked
    # symints for each element of the list.
    def tolist(self):
        assert self.dim() == 1, "NYI for higher dims"
        shape_env = self.fake_mode.shape_env
        out = []
        # Specialize on the length of the list
        for _ in range(self.shape[0]):
            s = shape_env.create_unbacked_symint()
            # max value?
            torch._check_is_size(s)
            torch._check(s >= 2)
            out.append(s)
        return out
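

# A minimal sketch of the tolist() special case above (illustrative; assumes a
# ShapeEnv-backed mode so unbacked SymInts can be allocated):
#
#     from torch.fx.experimental.symbolic_shapes import ShapeEnv
#     with FakeTensorMode(shape_env=ShapeEnv()):
#         t = torch.empty(3, dtype=torch.int64)
#         vals = t.tolist()    # three fresh unbacked SymInts, not concrete ints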


@dataclass(frozen=True)
class TensorMetadata:
    """
    The Tensor metadata relevant to hashing FakeTensors when caching.
    """

    dtype: torch.dtype
    shape: torch.Size
    stride: Tuple[Any, ...]
    device: torch.device
    layout: torch.layout
    memory_format: Optional[torch.memory_format]
    storage_offset: int
    storage_bytes: Optional[int]
    requires_grad: bool
    is_quantized: bool
    is_conj: bool
    is_neg: bool
    is_inference: bool
    is_sparse: bool  # read: is sparse COO
    is_coalesced: Optional[bool]
    dense_dim: Optional[int]
    sparse_dim: Optional[int]


def extract_tensor_metadata(t: torch.Tensor) -> "TensorMetadata":
    """
    Extract the TensorMetadata of a tensor.
    """
    memory_format: Optional[torch.memory_format] = suggest_memory_format(t)
    if is_sparse_any(t) or not t.is_contiguous(memory_format=memory_format):
        memory_format = None

    return TensorMetadata(
        dtype=t.dtype,
        shape=t.shape,
        stride=t.stride() if t.layout == torch.strided else (),
        device=t.device,
        layout=t.layout,
        memory_format=memory_format,
        storage_offset=t.storage_offset(),
        # Only set storage_bytes for tensors that have storage (not sparse)
        storage_bytes=t.untyped_storage().nbytes() if not t.is_sparse else None,
        requires_grad=t.requires_grad,
        is_quantized=t.is_quantized,
        is_conj=t.is_conj(),
        is_neg=t.is_neg(),
        is_inference=t.is_inference(),
        is_sparse=t.is_sparse,
        is_coalesced=t.is_coalesced() if t.is_sparse else None,
        dense_dim=t.dense_dim() if t.is_sparse else None,
        sparse_dim=t.sparse_dim() if t.is_sparse else None,
    )
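

# A minimal sketch of the metadata used for cache keys (illustrative values):
#
#     meta = extract_tensor_metadata(torch.empty_strided((2, 3), (3, 1)))
#     meta.shape           # torch.Size([2, 3])
#     meta.stride          # (3, 1)
#     meta.memory_format   # torch.contiguous_format
#     hash((meta,))        # frozen dataclass, so it participates in hashing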


class _DispatchCacheKey(list):
    """
    Key for the FakeTensor dispatch cache. Inspired by (copied from)
    _HashedSeq from the functools.lru_cache implementation.
    """

    __slots__ = "hashvalue"  # noqa: PLC0205

    def __init__(self, tup, hash=hash):
        self[:] = tup
        self.hashvalue = hash(tup)

    def __hash__(self):
        return self.hashvalue


@dataclass(frozen=True)
class _DispatchCacheEntry:
    """
    Entry type for the FakeTensor dispatch cache. Accounts for two possibilities:
    1) The op is inplace, and a hit means we need to alias the argument at a
       given index.
    2) We need to synthesize a new FakeTensor given tensor metadata. For view
       ops, we further capture the index of the arg to alias.
    """

    inplace_idx: Optional[int] = None
    metadata: Optional[TensorMetadata] = None
    view_idx: Optional[int] = None


@dataclass(frozen=True)
class _BypassDispatchCache(Exception):
    """
    Signals cases that should skip FakeTensor caching.
    """

    reason: str


@dataclass(frozen=True)
class DispatchCacheInfo:
    """
    Information about the state of the FakeTensor dispatch cache.
    """

    hits: int
    misses: int
    bypasses: Dict[str, int]
    size: int


# We keep one instantiation of `fake_tensor_converter` active
# for the duration of `with FakeTensorMode()`.
# This allows accurate storage aliasing across invocation of
# different operators. While this will keep all freshly allocated
# tensors alive during `FakeTensorMode`, there will be no
# new allocations of Tensors which have non-meta storage so
# memory should not significantly increase.
class FakeTensorMode(TorchDispatchMode):
    cache: Dict[_DispatchCacheKey, _DispatchCacheEntry] = {}
    cache_hits: int = 0
    cache_misses: int = 0
    cache_bypasses: Dict[str, int] = defaultdict(int)
    # Every time you retrace using the same fake tensor mode, you should
    # advance the epoch so we don't reuse unbacked memos
    epoch: int = 0
    in_kernel_invocation: bool = False

    def __init__(
        self,
        *,
        allow_fallback_kernels=True,
        allow_non_fake_inputs=False,
        shape_env=None,
        static_shapes=None,
        # TODO: This is a temporary measure, see
        # https://github.com/pytorch/pytorch/pull/126245#discussion_r1604185748
        # We're currently solely using this to impede population of
        # item_memo for 0d scalar tensor inputs when exporting, because this
        # causes things that used to be deferred runtime asserts to turn into
        # guards, and then the guards are just lost. We can potentially fix
        # this by ensuring guards also get put in the graph, but this is
        # pending a rework of how deferred runtime asserts are handled in
        # export. Once that's done, we can remove this.
        export=False,
    ):
        log.debug("create_mode 0x%x", id(self))
        self.allow_fallback_kernels = allow_fallback_kernels

        import torch._dynamo.config
        import torch._functorch.config

        self.propagate_real_tensors = (
            torch._functorch.config.fake_tensor_propagate_real_tensors
        )
        self.fake_tensor_converter = FakeTensorConverter(
            copy_data=self.propagate_real_tensors,
            export=export,
        )

        if static_shapes is not None:
            self.static_shapes = static_shapes
        else:
            self.static_shapes = shape_env is None

        # This is temporarily patched to True in Dynamo to grandfather in some
        # places where we unconditionally allow scalar outputs, TO BE REMOVED
        self.allow_scalar_outputs = False

        self._allow_unsafe_data_ptr_access = (
            torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access
        )
        self.allow_meta = torch._functorch.config.fake_tensor_allow_meta
        self.cache_enabled = (
            torch._dynamo.config.fake_tensor_cache_enabled
            and not self.propagate_real_tensors
        )
        self.cache_crosscheck_enabled = (
            torch._dynamo.config.fake_tensor_cache_crosscheck_enabled
        )

        # A flag that controls whether we want to invoke ops on a mix of
        # real weights/global variables and fake inputs
        self.allow_non_fake_inputs = allow_non_fake_inputs

        # [in_kernel_invocation]
        # when FakeTensor is invoked in user code, .device should return
        # the fake_device of the tensor so that code such as `if x.is_cuda`
        # or torch.zeros([10, 10], device=x.device) continues to execute as if
        # the FakeTensor were real. However, within kernel execution, we return
        # the `Meta` device because all computation within the kernels should
        # behave as if the Tensors are on meta devices. Kernels should allocate
        # new tensors on meta devices, and checks like `is_meta` should return true.
        # within python refs, we always return the real device by defining
        # the device property
        self.in_kernel_invocation = False

        # True if we enter'ed and actually enabled fake tensor mode,
        # false if it was a no-op. Not thread safe but neither is
        # in_kernel_invocation
        # If another fake mode was already active when we enter, we also stash it here.
        # That way when we exit, we know to re-enable the previous fake mode.
        self.enter_stack: List[
            Tuple[bool, Optional[TorchDispatchMode], Optional[_bool]]
        ] = []

        self.shape_env: ShapeEnv = shape_env

        self._stack_trace = traceback.extract_stack()
        self._stack = None

        # Indicates to our torch_dispatch dispatching infra that
        # this is an "infra" mode with lower dispatching precedence.
        self._mode_key = torch._C._TorchDispatchModeKey.FAKE

    # Typically, there is only one fake tensor mode and you test for it by
    # doing an isinstance test. However, in some situations, there might be
    # TWO fake tensor modes. The canonical example of this is exporting
    # a fake model: there is an outer fake mode created by the user, and
    # an inner fake mode created by Dynamo. The two phase process is required
    # because the outer fake mode typically won't have a ShapeEnv, even if
    # the user is interested in exporting with dynamic shapes (so the inner
    # fake mode will actually have a ShapeEnv and swap in symbolic sizes.)
    #
    # In this case, it's insufficient to test only one FakeTensor: you need
    # to distinguish between our fake tensor and other fake tensors. That's
    # what this function does.
    def is_our_fake(self, t):
        return isinstance(t, FakeTensor) and t.fake_mode is self

    # If we should avoid device init. This changes the behavior of various APIs:
    # - We avoid constant-prop on Tensors with ops that move them to another device
    # - We change the torch.tensor ctor contract to never materialize
    #   tensors on device
    #   (see NOTE: [torch.tensor, lift_fresh, and device movement])
    @property
    def avoid_device_init(self):
        return not torch.cuda.is_available()

    @property
    def stack(self):
        if self._stack is None:
            self._stack = "".join(traceback.format_list(self._stack_trace))
        return self._stack

    @count
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # FakeTensorMode should not be set when we're inside of it.
        assert (
            torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None
        ), func
        try:
            return self.dispatch(func, types, args, kwargs)
        except TypeError:
            log.exception("fake tensor raised TypeError")
            raise

    # No-op if FakeTensorMode is already in use
    def __enter__(self):
        prev_only_lift_cpu_tensors = None
        if self.avoid_device_init:
            # See NOTE: [torch.tensor, lift_fresh, and device movement]
            prev_only_lift_cpu_tensors = torch._C._only_lift_cpu_tensors()
            torch._C._set_only_lift_cpu_tensors(True)
        maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key)
        if self is not maybe_prev_fake_mode:
            self.enter_stack.append(
                (True, maybe_prev_fake_mode, prev_only_lift_cpu_tensors)
            )
            return super().__enter__()
        else:
            # no-op (still need to re-set the fake mode though since we unset it)
            torch._C._set_dispatch_mode(self)
            self.enter_stack.append((False, None, prev_only_lift_cpu_tensors))
        return self

    def __exit__(self, a, b, c):
        (
            live,
            maybe_prev_fake_mode,
            maybe_prev_only_lift_cpu_tensors,
        ) = self.enter_stack.pop()
        if live:
            out = super().__exit__(a, b, c)
            # Re-enable the previous fake mode, if there was one.
            if maybe_prev_fake_mode is not None:
                torch._C._set_dispatch_mode(maybe_prev_fake_mode)
            if maybe_prev_only_lift_cpu_tensors is not None:
                torch._C._set_only_lift_cpu_tensors(maybe_prev_only_lift_cpu_tensors)
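
    # Minimal usage sketch (illustrative): inside the mode, factory functions
    # produce FakeTensors; re-entering the same mode is tracked as a no-op via
    # enter_stack.
    #
    #     mode = FakeTensorMode()
    #     with mode:
    #         a = torch.randn(2, 2)    # FakeTensor
    #         with mode:               # no-op re-entry
    #             b = a + 1            # still a FakeTensor
    #     assert is_fake(a) and is_fake(b)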

    @classmethod
    def cache_info(cls) -> DispatchCacheInfo:
        """
        Query the state of the dispatch cache.
        """
        return DispatchCacheInfo(
            FakeTensorMode.cache_hits,
            FakeTensorMode.cache_misses,
            dict(FakeTensorMode.cache_bypasses),
            len(FakeTensorMode.cache),
        )

    @classmethod
    def cache_clear(cls):
        """
        Clear the dispatch cache.
        """
        cls.cache_hits = 0
        cls.cache_misses = 0
        cls.cache_bypasses.clear()
        cls.cache.clear()
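
    # Minimal sketch of inspecting the dispatch cache (illustrative; assumes
    # caching is enabled via torch._dynamo.config.fake_tensor_cache_enabled):
    #
    #     FakeTensorMode.cache_clear()
    #     with FakeTensorMode():
    #         x = torch.empty(8)
    #         y = x + 1    # first occurrence misses the cache
    #         z = x + 1    # identical call can now hit
    #     info = FakeTensorMode.cache_info()   # hits / misses / bypasses / size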

    def _cached_dispatch_impl(
        self,
        func: OpOverload,
        types: Tuple[Any, ...],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ):
        """
        Lookup a cache entry for the given arguments. If none exists, dispatch
        and cache the result (if the result is eligible for caching).
        """
        output: Union[FakeTensor, _Unassigned] = _UNASSIGNED
        try:
            key = self._cache_key(func, args, kwargs)
            entry = FakeTensorMode.cache.get(key, None)
            if entry is not None:
                output = self._output_from_cache_entry(entry, func, args)
                FakeTensorMode.cache_hits += 1
                if self.cache_crosscheck_enabled:
                    # For debugging / testing: Validate that the output synthesized
                    # from the cache matches the output created by normal dispatch.
                    self._crosscheck_cache_output(output, func, types, args, kwargs)
            else:
                self._validate_cache_key(func, args, kwargs)
                output = self._dispatch_impl(func, types, args, kwargs)
                entry = self._make_cache_entry(key, func, args, kwargs, output)
                FakeTensorMode.cache[key] = entry
                FakeTensorMode.cache_misses += 1
        except _BypassDispatchCache as e:
            FakeTensorMode.cache_bypasses[e.reason] += 1

        if output is _UNASSIGNED:
            output = self._dispatch_impl(func, types, args, kwargs)

        return output

    def _cache_key(
        self,
        func: OpOverload,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> _DispatchCacheKey:
        """
        Create a cache key given the dispatch args. Raises _BypassDispatchCache
        for any situation that precludes caching.
        """
        key_values = (
            func,
            # Translate any FakeTensor args to metadata.
            self._prep_args_for_hash(args) if args else (),
            self._prep_args_for_hash(kwargs) if kwargs else (),
            # Capture the default_dtype mode since that can affect the output tensor,
            # e.g., when operating on constant float values.
            torch.get_default_dtype(),
            # Capture the current device to support, e.g., caching tensor creation,
            # where there isn't necessarily a tensor to take the device from.
            torch._C._get_default_device(),
            # We want to create tensors from cached metadata only when the inference
            # mode is the same.
            torch.is_inference_mode_enabled(),
            # Shape env settings could affect behavior. One example seen in the wild:
            # Disallowing dynamic shapes can introduce a DynamicOutputShapeException
            # where it wasn't seen on a previous instance of the same op.
            self.shape_env.settings if self.shape_env else None,
        )
        return _DispatchCacheKey(key_values)

    def _validate_cache_key(
        self,
        func: OpOverload,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ):
        """
        Validate that the cache key generated by _cache_key will be
        reasonable.
        """
        # Avoid caching for any ops that would require a more sophisticated
        # caching implementation, e.g., data dependent ops or ops that modify
        # the inputs.
        if torch.Tag.data_dependent_output in func.tags:
            raise _BypassDispatchCache("data dependent output")

        if torch.Tag.dynamic_output_shape in func.tags:
            raise _BypassDispatchCache("dynamic output shape")

        if torch.Tag.inplace_view in func.tags:
            raise _BypassDispatchCache("inplace view")

        if func == aten._unsafe_view.default:
            raise _BypassDispatchCache("unsafe view")

        if func in self.lift_fns:
            raise _BypassDispatchCache("lift")

        if func.name() == "inductor::resize_storage_bytes_":
            raise _BypassDispatchCache("inductor::resize_storage_bytes_")

        if not torch._library.utils.is_builtin(func):
            raise _BypassDispatchCache("non-builtin")

        # In order to handle storage aliasing, we need to establish the alias
        # for any view op on a cache hit. But CompositeImplicitAutograd ops may
        # or may not alias the input, so just punt on caching these.
        if func.is_view and torch._C._dispatch_has_kernel_for_dispatch_key(
            func.name(), torch._C.DispatchKey.CompositeImplicitAutograd
        ):
            raise _BypassDispatchCache("CompositeImplicitAutograd")

    def _prep_args_for_hash(self, args: Any) -> Any:
        """
        Translate the provided args into a form suitable for caching at FakeTensor
        dispatch, i.e., convert unhashable types like lists & dicts into tuples and
        convert FakeTensors into metadata. Raises _BypassDispatchCache to signal
        unsupported cases that should bypass caching.
        """
        if isinstance(args, dict):
            args = list(args.keys()) + list(args.values())

        result: List[Any] = []
        for arg in args:
            if isinstance(arg, FakeTensor):
                if not self.is_our_fake(arg):
                    raise _BypassDispatchCache("not our fake")
                if arg._has_symbolic_sizes_strides:
                    raise _BypassDispatchCache("symbolic shape")
                if arg.constant is not None:
                    raise _BypassDispatchCache("constant attribute")
                if arg.is_sparse:
                    raise _BypassDispatchCache("sparse tensor")
                if arg.layout in [
                    torch.sparse_csr,
                    torch.sparse_csc,
                    torch.sparse_bsr,
                    torch.sparse_bsc,
                ]:
                    # Does this subsume arg.is_sparse?
                    raise _BypassDispatchCache("sparse tensor layout")
                # sparse tensors don't have storage, so check is after
                if isinstance(arg.untyped_storage().nbytes(), torch.SymInt):
                    raise _BypassDispatchCache("symbolic nbytes")
                if is_sparse_compressed(arg):
                    raise _BypassDispatchCache("sparse compressed tensor")
                result.append(extract_tensor_metadata(arg))
            elif isinstance(arg, torch.Tensor):
                raise _BypassDispatchCache("non-fake tensor")
            elif isinstance(arg, (torch.SymBool, torch.SymInt, torch.SymFloat)):
                raise _BypassDispatchCache("symbolic shape")
            elif isinstance(arg, (list, tuple, dict)):
                result.extend(self._prep_args_for_hash(arg))
            else:
                # It's important to capture the type of the arg since, e.g., 1 and 1.0
                # hash to the same value, but can produce different dtypes for the
                # output tensor.
                result.append((type(arg), arg))

        return tuple(result)

    def _make_cache_entry(
        self,
        key: _DispatchCacheKey,
        func: OpOverload,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        output: FakeTensor,
    ) -> _DispatchCacheEntry:
        """
        Make a cache entry object for the given 'output' Tensor. Raises
        _BypassDispatchCache if the output tensor has characteristics that
        prevent caching it.
        """
        # Some ops return tuples of Tensors, but it's rare, so avoid
        # the complexity of caching other types.
        if not isinstance(output, FakeTensor):
            raise _BypassDispatchCache("non-FakeTensor output")

        # Avoid caching FakeTensors with constants attached since those
        # can be invalidated.
        if output.constant is not None:
            raise _BypassDispatchCache("constant attribute")

        # TODO: support caching sparse outputs?
        if output.is_sparse:
            raise _BypassDispatchCache("sparse output")
        if is_sparse_compressed(output):
            raise _BypassDispatchCache("sparse compressed output")

        # Can an in-place op really reference a kwarg? If so, then we need
        # to extend the implementation to handle it.
        for kval in kwargs.values():
            if id(kval) == id(output):
                raise _BypassDispatchCache("kwarg aliases output")

        # If this is an in-place op, the entry records which input arg is aliased.
        for idx in range(len(args)):
            if id(args[idx]) == id(output):
                return _DispatchCacheEntry(
                    inplace_idx=idx, metadata=None, view_idx=None
                )

        # Otherwise, create an entry that records the output tensor's metadata.
        view_idx = None
        if func.is_view:
            idxs = [i for i, t in enumerate(args) if isinstance(t, torch.Tensor)]
            assert len(idxs) == 1
            view_idx = idxs[0]

        metadata = extract_tensor_metadata(output)
        entry = _DispatchCacheEntry(
            inplace_idx=None, metadata=metadata, view_idx=view_idx
        )

        # N.B.: Some checks for bypassing the cache would be performed on the
        # output tensor synthesized from the cached metadata. As an optimization,
        # we can synthesize a tensor here and do the checks on that instance.
        # This approach keeps the (more frequent) cache-hit path as lightweight
        # as possible.
        synth_output = self._output_from_cache_entry(entry, func, args)

        # Make sure the dispatch_key_set from the synthesized output tensor will
        # be the same.
        synth_key_set = torch._C._dispatch_key_set(synth_output)
        key_set = torch._C._dispatch_key_set(output)
        if synth_key_set != key_set:
            raise _BypassDispatchCache("dispatch_key_set mismatch")

        return entry

    def _output_from_cache_entry(
        self, entry: _DispatchCacheEntry, func: OpOverload, args: Tuple[Any, ...]
    ) -> FakeTensor:
        """
        Create a new FakeTensor from the cache entry.
        """
        if entry.inplace_idx is not None:
            # This is an in-place op; return the aliased arg.
            return args[entry.inplace_idx]

        # Synthesize a new FakeTensor with the cached metadata.
        metadata = entry.metadata
        assert metadata and not metadata.is_sparse

        empty = torch.empty_strided(
            metadata.shape,
            metadata.stride,
            dtype=metadata.dtype,
            layout=metadata.layout,
            device="meta",
            requires_grad=metadata.requires_grad,
        )

        if metadata.is_conj:
            torch._C._set_conj(empty, True)
        if metadata.is_neg:
            torch._C._set_neg(empty, True)

        maybe_suppress: Callable[[], Any] = contextlib.nullcontext
        if self.shape_env is not None:
            maybe_suppress = self.shape_env.suppress_guards

        if func.is_view:
            # For view ops, the storage should be the same as the tensor input.
            storage = args[cast(int, entry.view_idx)].untyped_storage()
            with in_kernel_invocation_manager(self), maybe_suppress():
                empty.set_(
                    storage, metadata.storage_offset, metadata.shape, metadata.stride
                )
        elif metadata.storage_offset != 0:
            storage = empty.untyped_storage()
            with in_kernel_invocation_manager(self), maybe_suppress():
                empty.set_(
                    storage, metadata.storage_offset, metadata.shape, metadata.stride
                )
        if metadata.storage_bytes == 0:
            empty.untyped_storage().resize_(0)

        return FakeTensor(self, empty, metadata.device)
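
    # Illustrative note (not part of the library): for an in-place op such as
    # aten.add_.Tensor the entry stores inplace_idx=0, so this method simply
    # returns args[0]; for a view op such as aten.view.default the entry stores
    # the output metadata plus view_idx, and the synthesized output is rebased
    # onto args[view_idx].untyped_storage() via set_() above.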

    def _crosscheck_cache_output(
        self,
        output: FakeTensor,
        func: OpOverload,
        types: Tuple[Any, ...],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ):
        """
        Helper to validate that the output synthesized from the cache matches
        the output created by normal dispatch.
        """
        try:
            true_output = self._dispatch_impl(func, types, args, kwargs)
        except Exception as e:
            raise RuntimeError(
                f"FakeTensor cache crosscheck failure: func={func}, "
                f"args={args}, kwargs={kwargs}: Dispatch raised={e}"
            ) from e
        try:
            assert_metadata_eq(assert_eq, true_output, output)
        except Exception as e:
            raise RuntimeError(
                f"FakeTensor cache crosscheck failure: func={func}, "
                f"args={args}, kwargs={kwargs}"
            ) from e

    def dispatch(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        with no_dispatch():
            log.debug("%s %s %s", func, args, kwargs)

        if func in _DISPATCH_META_HANDLERS:
            return _DISPATCH_META_HANDLERS[func](args)

        if log.getEffectiveLevel() <= logging.DEBUG:
            log.debug(
                "%sFakeTensorMode.__torch_dispatch__: %s", " " * RECURSION_COUNT, func
            )
            # NOTE: incr is intentionally unused for a RAII pattern
            incr = IncrementRecursionCount()

        # Some attribute queries that can be serviced directly
        # See Note [is_coalesced is dispatched]
        if func in _DISPATCH_HANDLE_DIRECTLY:
            # NB: no_dispatch is ok here too, this func is very simple
            with in_kernel_invocation_manager(self):
                return func(*args, **kwargs)

        if self.cache_enabled:
            return self._cached_dispatch_impl(func, types, args, kwargs)
        else:
            return self._dispatch_impl(func, types, args, kwargs)
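
    # Illustrative usage sketch (assumes only the public FakeTensorMode API):
    #   with FakeTensorMode() as mode:
    #       x = torch.empty(4, 8)   # factory call is intercepted; x is a FakeTensor
    #       y = x @ x.T             # routed through dispatch() above; y has shape (4, 4)
    # Neither call allocates real storage; only shapes, dtypes, and devices are tracked.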

    def _dispatch_impl(self, func, types, args, kwargs) -> FakeTensor:
        flat_args, args_spec = pytree.tree_flatten((args, kwargs))

        flat_arg_fake_tensors = [t for t in flat_args if self.is_our_fake(t)]
        has_symbolic_sizes = any(
            i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors
        ) or any(isinstance(a, torch.SymInt) for a in flat_args)

        converter = self.fake_tensor_converter

        is_lift_func = func in self.lift_fns

        # To constant propagate through these functions:
        # 1. If this is a lift due to a torch.tensor call, the input tensor is
        #    guaranteed to be a constant, so we keep a copy of the original argument
        #    along so we can query it if we're asked to item() it at some later point.
        #    (Note that you can always call a lift fn manually, so we do have to
        #    check if there are any fake tensors!)
        # 2. Some functions that allow Python numbers to bind to Tensors, e.g., torch.div
        if (is_lift_func and not flat_arg_fake_tensors) or (
            should_allow_numbers_as_tensors(func)
            and not has_symbolic_sizes
            and not flat_arg_fake_tensors
        ):
            assert all(
                t.constant is not None for t in flat_arg_fake_tensors
            ), f"{func} should not have fake inputs without constants"
            const_flat_args = [
                a.constant if self.is_our_fake(a) else a for a in flat_args
            ]
            const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec)
            out = func(*const_args, **const_kwargs)
            if type(out) is torch.Tensor and self.may_turn_const(out):
                # NB: not in_kernel_invocation_manager because we're doing real
                # compute here
                # NB: no_dispatch() here is VERY DANGEROUS (like, segfault
                # dangerous) if this is actually a wrapper subclass tensor,
                # therefore the exact type test above
                with no_dispatch():
                    out = out.clone()
                return converter.from_real_tensor(self, out, make_constant=True)

        # See [subclass inputs] below
        # NB: If you're seeing a mysterious infinite loop involving fake
        # tensor, it might be related to this line. Though I'm not sure
        # how you'll know to read this comment, as this line won't show up
        # in the stack trace.
        has_unrecognized_types = _check_for_subclass(flat_args)
        if has_unrecognized_types:
            unrecognized_types = [
                type(x) for x in flat_args if _check_for_subclass_arg(x)
            ]
            not_implemented_log.debug(
                "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        # If we are in dispatch mode, we will enter this function even if the inputs
        # are not FakeTensors. For now, throw if there are any non-Fake Tensor inputs
        # and just support constructors.

        # this is generated from torch.tensor(), which does not use the
        # dispatcher, to allow wrapper subclasses to wrap the new tensor
        if is_lift_func:
            assert len(kwargs) == 0 and len(args) == 1, f"{args} {kwargs}"

            if type(args[0]) is torch.Tensor:
                return converter.from_real_tensor(self, args[0])

        # If we are trying to avoid device init, then we need to avoid constant
        # prop on constant tensors for ops that change devices.
        avoiding_device_init = False
        if self.avoid_device_init:
            if (
                func == torch.ops.aten._to_copy.default
                and "device" in kwargs
                and kwargs["device"] != "cpu"
            ):
                avoiding_device_init = True
            if func == torch.ops.prims.device_put.default:
                avoiding_device_init = True

        # Recompute flat_arg_fake_tensors here again in case some of the inputs
        # were real tensors and fakified in validate_and_convert_non_fake_tensors
        (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors(
            func, converter, flat_args, args_spec
        )
        del args, kwargs  # Invalidated

        # The current constant handling only supports tracing systems
        # (aot autograd, torchdynamo) where each operation is run consecutively.
        # Because each operation is run in order, we can trace out and support
        # sequences like: x = torch.tensor(0.); y = x.add_(1)
        # Whenever a constant is written to, but with inputs that cannot be evaluated
        # statically, such as random_(), we invalidate all constants that alias the input.
        # We will rely on functionalization for use of fake tensor constants as
        # persistent objects on an FX Graph.

        # We dispatch size/stride/numel on the FakeTensor not its constant, so bail on inplace_view
        all_constant = all(e.constant is not None for e in flat_arg_fake_tensors)
        if (
            torch.Tag.nondeterministic_seeded not in func.tags
            and torch.Tag.inplace_view not in func.tags
            and all_constant
            and len(flat_arg_fake_tensors) != 0
            and not has_symbolic_sizes
            and not avoiding_device_init
        ):
            const_flat_args = [
                a.constant if self.is_our_fake(a) else a for a in flat_args
            ]
            const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec)

            # NB: not in_kernel_invocation_manager(self) as we want to do REAL
            # compute
            with no_dispatch():
                out = func(*const_args, **const_kwargs)

            flat_out = pytree.tree_leaves(out)
            flat_out_tensors = [t for t in flat_out if isinstance(t, torch.Tensor)]
            all_constant = all(self.may_turn_const(t) for t in flat_out_tensors)

            if all_constant:
                return pytree.tree_map_only(
                    torch.Tensor,
                    lambda t: converter.from_real_tensor(self, t, make_constant=True),
                    out,
                )

            # we weren't able to turn the outputs into constants,
            # so invalidate all constants that might be aliases of the outputs
            for ten in flat_out_tensors:
                converter.invalidate_constant_aliases(ten)

        # we are falling through to running with non-constant tensors; any input
        # constant that is written to must be invalidated
        args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
        self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)

        def maybe_to_real_tensor(t):
            if isinstance(t, FakeTensor):
                return t.real_tensor
            elif isinstance(t, SymTypes):
                return t.node.pytype(
                    t.node.expr.xreplace(self.shape_env.var_to_val).xreplace(
                        self.shape_env.unbacked_var_to_val
                    )
                )
            else:
                return t

        from torch.fx.experimental.symbolic_shapes import (
            compute_unbacked_bindings,
            free_unbacked_symbols,
            SymTypes,
        )

        nil = object()

        real_out = nil
        if (
            self.propagate_real_tensors
            and all(e.real_tensor is not None for e in flat_arg_fake_tensors)
            # TODO: Handle SymFloat/SymBool
            and not any(
                (
                    isinstance(a, torch.SymInt)
                    and (syms := free_unbacked_symbols(a))
                    and any(s not in self.shape_env.unbacked_var_to_val for s in syms)
                )
                for a in flat_args
            )
        ):
            real_flat_args = [maybe_to_real_tensor(a) for a in flat_args]
            real_args, real_kwargs = pytree.tree_unflatten(real_flat_args, args_spec)
            real_out = func(*real_args, **real_kwargs)
        elif self.propagate_real_tensors:
            # This can happen occasionally legitimately, specifically when you
            # are inside the meta of a data dependent operation and you create
            # a tensor on an unbacked SymInt; at this point in time we don't
            # know what the unbacked SymInt is, but we will know later.
            # However, if there's a bug in the condition above, this condition
            # will also trigger.
            log.debug(
                "propagate_real_tensors skipped %s(%s, %s) %s",
                func,
                flat_arg_fake_tensors,
                flat_args,
                self.shape_env.unbacked_var_to_val if self.shape_env else None,
            )

        def maybe_propagate_real_tensors(fake_out):
            import sympy

            def go(t, real_t):
                if isinstance(t, FakeTensor):
                    # NB: unconditionally overwrite
                    t.real_tensor = real_t
                elif isinstance(t, SymTypes) and free_unbacked_symbols(t):
                    if isinstance(t.node.expr, sympy.Symbol):
                        self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t)

            if real_out is not nil:
                tree_map_(go, fake_out, real_out)

                # If a data-dependent op is used in a decomposition, we
                # may need to get the unbacked settings "early"
                # TODO: Is this really needed?
                compute_unbacked_bindings(self.shape_env, fake_out, peek=True)

            return fake_out

        # Try for fastpath
        if has_symbolic_sizes:
            fast_impl = get_fast_op_impls().get(func)
            if fast_impl is not None:
                return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))

        # If there's a Python meta, prefer that over the decomposition
        from torch._decomp import meta_table as meta_table

        if func not in meta_table and not self.cpp_meta_supports_symint(func):
            from torch._decomp import decomposition_table

            # Prefer Python decompositions over C++ ones
            if func in decomposition_table and (
                has_symbolic_sizes
                or (
                    # TODO: Remove these exclusions, so that we can remove
                    # this leg entirely
                    torch_decomp_decompositions(func)
                    and all(not e.is_sparse for e in flat_arg_fake_tensors)
                )
            ):
                with self:
                    return decomposition_table[func](*args, **kwargs)

            with self:
                # Decomposes CompositeImplicitAutograd ops
                r = func.decompose(*args, **kwargs)
                if r is not NotImplemented:
                    return r

        # prims already wrap FakeTensor inputs to FakeTensor outputs
        # and do device logic; we don't need to do anything but run them
        # and ensure that Meta kernels are dispatched to (see
        # Fake Tensor Dispatch Keys)
        # TODO - we should use the prim aten impl
        # TODO - fix prims complex ops
        if (
            "prims::" in func._schema.name
            and hasattr(func, "prim_meta_impl")
            and not stride_incorrect_op(func)
        ):
            with self:
                return maybe_propagate_real_tensors(
                    func.prim_meta_impl(*args, **kwargs)
                )

        # Users can register FakeTensor rules for custom operators.
        # Call them if they exist.
        maybe_abstract_impl = torch._library.simple_registry.singleton.find(
            func.name()
        ).abstract_impl.kernel
        if maybe_abstract_impl:
            ctx = torch._library.abstract_impl.AbstractImplCtx(self, func)
            with torch._library.abstract_impl.set_ctx_getter(lambda: ctx), self:
                result = maybe_abstract_impl(*args, **kwargs)
                return maybe_propagate_real_tensors(result)

        # special handling for funcs registered through `register_op_impl`,
        # e.g., manipulating args on constructor calls to construct meta tensors
        # and then afterwards wrapping them into a FakeTensor
        for run_impl_check, op_impl in op_implementations_checks:
            if run_impl_check(func):
                op_impl_out = op_impl(self, func, *args, **kwargs)
                if op_impl_out is not NotImplemented:
                    return maybe_propagate_real_tensors(op_impl_out)

        def maybe_run_unsafe_fallback(error=None):
            # We infer that a custom op returning None has a trivial meta: just
            # return None. Custom ops are not allowed to mutate the metadata
            # of their inputs, so this is safe.
            if torch._library.utils.can_generate_trivial_fake_impl(func):
                return None
            # no meta kernel registered, fallback to kernel for the device
            if has_symbolic_sizes or not self.can_run_unsafe_fallback(func):
                raise UnsupportedOperatorException(func)
            if error is None:
                error = UnsupportedOperatorException(func)
            return run_fallback_kernel(self, func, flat_args, args_spec, error)

        # Optimization: If there is no Meta kernel, it takes a surprisingly long
        # amount of time to catch the NotImplementedError, so we check it here.
        if not has_meta(func):
            return maybe_propagate_real_tensors(maybe_run_unsafe_fallback())

        # run the kernel registered to meta for func, which includes
        # python meta registrations, prims, decomps, and c++ meta fns (structured kernels)
        # It's possible that the kernel will return NotImplementedError
        try:
            with in_kernel_invocation_manager(self):
                r = func(*args, **kwargs)
        except NotImplementedError as not_implemented_error:
            return maybe_run_unsafe_fallback(not_implemented_error)
        except Exception:
            log.exception("failed while attempting to run meta for %s", func)
            raise

        return maybe_propagate_real_tensors(
            self.wrap_meta_outputs_with_default_device_logic(
                r, func, flat_args, device=kwargs.get("device")
            )
        )
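
    # Illustrative sketch of the "custom operator fake rule" path consulted above
    # (assumes the torch.library registration APIs of this PyTorch version; the
    # "mylib::bar" op name and bar_abstract function are hypothetical):
    #   torch.library.define("mylib::bar", "(Tensor x) -> Tensor")
    #   @torch.library.impl_abstract("mylib::bar")
    #   def bar_abstract(x):
    #       return torch.empty_like(x)
    # A kernel registered this way is what maybe_abstract_impl resolves to in
    # _dispatch_impl.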

    # WARNING: DO NOT add any additional namespaces/operators here if they refer to operators
    # outside of the pytorch/pytorch library! Any pre-existing things here
    # are either in the pytorch/pytorch library or have been grandfathered in.
    # The fallback does not always work and MAY CRASH and emit unreadable error messages
    # so it should not be allowed by default.
    _can_run_unsafe_fallback_allowed_namespaces = ordered_set(
        "debugprims",
        "prims",
        "aten",
        "xla",
        "vision",
        "torchtext",
        "torchaudio",
        "quantized",
    )

    def can_run_unsafe_fallback(self, func: OpOverload):
        if not self.allow_fallback_kernels:
            return False
        # It's OK to try the fallback for built-in ops (e.g. aten, prims)
        # because we control and test these but the fallback leads to unexpected behavior
        # in user-defined custom ops
        return (
            func.namespace in self._can_run_unsafe_fallback_allowed_namespaces
            or func.name() == "fbgemm::gmm"
        )

    def validate_and_convert_non_fake_tensors(
        self, func, converter, flat_args, args_spec
    ):
        """
        Check that the tensors in flat_args are fake tensors and, if not, try to
        convert them to fake tensors. Returns the validated flat args along with
        a flattened list of the fake tensor args.
        """
        flat_arg_fake_tensors: List[Any] = []

        def validate(x):
            if not isinstance(x, torch.Tensor):
                return x

            nonlocal flat_arg_fake_tensors
            if not self.is_our_fake(x):
                if torch.Tag.inplace_view in func.tags:
                    args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
                    raise AssertionError(
                        f"Can't call metadata mutating ops on non-Fake Tensor inputs. Found in {render_call(func, args, kwargs)}"
                    )
                if not self.allow_non_fake_inputs:
                    if isinstance(x, FakeTensor) and x.fake_mode is not self:
                        raise AssertionError("Mixing fake modes NYI")
                    args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
                    raise AssertionError(
                        f"Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode "
                        f"with 'allow_non_fake_inputs'. Found in {render_call(func, args, kwargs)}"
                    )

                x = converter.from_real_tensor(self, x)

            flat_arg_fake_tensors.append(x)
            return x

        validated_args = [validate(a) for a in flat_args]
        return validated_args, flat_arg_fake_tensors
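
    # Illustrative note (not part of the library): with the default
    # allow_non_fake_inputs=False, passing a real torch.Tensor alongside fake
    # tensors hits the AssertionError above; a mode constructed as
    # FakeTensorMode(allow_non_fake_inputs=True) instead converts the real
    # tensor on the fly via converter.from_real_tensor().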

    def wrap_meta_outputs_with_default_device_logic(self, r, func, flat_args, device):
        converter = self.fake_tensor_converter

        # Lazily initialized, in case there are no tensor returns
        common_device = None
        has_scalar_only_inputs = False

        def wrap(e):
            nonlocal common_device
            nonlocal has_scalar_only_inputs

            if not isinstance(e, torch.Tensor):
                return e

            if common_device is None:
                (
                    common_device,
                    has_scalar_only_inputs,
                ) = FakeTensor._find_common_device(func, flat_args)

            is_our_fake = self.is_our_fake(e)
            if is_our_fake:
                torch._check(
                    e.device == common_device,
                    lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}",
                )
                return e
            elif converter is not None:
                if has_scalar_only_inputs:
                    # Under FakeTensorMode, an op that accepts only scalar inputs, such as
                    # aten.add/sub/mul/div, returns a real scalar tensor on CPU.
                    # See TensorMeta() in _prims/__init__.py for details.
                    # We thus directly convert the real tensor to a fake tensor.
                    return converter.from_real_tensor(self, e)
                else:
                    return converter.from_meta_and_device(
                        self, e, device or common_device
                    )
            else:
                return e

        return tree_map(wrap, r)

    _cpp_meta_supports_symint = ordered_set(
        aten.empty.memory_format,
        aten.empty_strided.default,
        aten.as_strided_scatter.default,
        aten.as_strided.default,
        aten.as_strided_.default,
        aten.zeros.default,
        aten.detach.default,
        aten.view_as_real.default,
        aten.view_as_complex.default,
        aten.set_.source_Storage_storage_offset,
        aten._sparse_coo_tensor_with_dims_and_tensors.default,
    )

    def cpp_meta_supports_symint(self, func):
        if torch.Tag.view_copy in func.tags:
            return True
        return func in self._cpp_meta_supports_symint

    lift_fns = ordered_set(aten.lift_fresh.default, aten.lift_fresh_copy.default)

    def may_turn_const(self, t):
        return (
            t.numel() <= CONSTANT_NUMEL_LIMIT
            and not t.is_sparse
            and not self.is_our_fake(t)
            and not t.device.type == "meta"
        )

    def invalidate_written_to_constants(
        self, func, flat_arg_fake_tensors, args, kwargs
    ):
        any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)
        schema_info = get_schema_info(func)
        if any_constant and schema_info.is_mutable():
            _, new_kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            for k, v in new_kwargs.items():
                k = k if (k != "input" or schema_info.has_argument(k)) else "self"
                if (
                    self.is_our_fake(v)
                    and schema_info.is_mutable(k)
                    and v.constant is not None
                ):
                    self.fake_tensor_converter.invalidate_constant_aliases(v.constant)

    def from_tensor(
        self,
        tensor,
        *,
        static_shapes=None,
        source: Optional[Source] = None,
        symbolic_context=None,
        trace=True,
    ):
        shape_env: Optional[ShapeEnv] = self.shape_env
        if static_shapes is None:
            static_shapes = self.static_shapes
        if static_shapes:
            assert (
                symbolic_context is None
            ), "cannot set both static_shapes and symbolic_context"
            shape_env = None
        return self.fake_tensor_converter.from_real_tensor(
            self,
            tensor,
            shape_env=shape_env,
            source=source,
            symbolic_context=symbolic_context,
            trace=trace,
        )
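
    # Illustrative usage sketch (not part of the library):
    #   mode = FakeTensorMode()
    #   real = torch.randn(2, 3)
    #   fake = mode.from_tensor(real)                    # FakeTensor reporting real's
    #                                                    # device, shape, and dtype
    #   fake_s = mode.from_tensor(real, static_shapes=True)  # shapes never symbolicized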


# NB: returns fake tensors
def run_fallback_kernel(
    fake_mode, func, flat_args, args_spec, orig_not_implemented_exception
):
    # these should all be supported, just to be safe
    # avoid fallback for operators which inplace modify metadata
    # because the input fake tensors would be unmodified
    if torch.Tag.inplace_view in func.tags:
        raise orig_not_implemented_exception

    inp_impls = {}

    # Don't use in_kernel_invocation_manager(fake_mode) as we want to do
    # REAL compute (not with meta device)
    with no_dispatch():

        def to_real_tensor(e):
            if fake_mode.is_our_fake(e):
                out = torch.zeros_like(e, device=e.fake_device)
                if e.is_sparse:
                    out._coalesced_(e.is_coalesced())
                inp_impls[id(out)] = e
                return out
            return e

        flat_args = [to_real_tensor(a) for a in flat_args]
        args, kwargs = pytree.tree_unflatten(flat_args, args_spec)

        r = func(*args, **kwargs)

    tensor_impls = set()
    storages = set()

    for e in flat_args:
        if isinstance(e, torch.Tensor):
            if not e.is_sparse:
                storages.add(e._typed_storage()._cdata)

    # TODO: also check metadata change on inputs
    # proper aliasing/metadata relationship between outputs and inputs will
    # not be set up, because of the conversion to device, unless we can reuse
    # an input impl

    def map_out(e):
        if id(e) not in inp_impls and (
            isinstance(e, torch.Tensor)
            and not e.is_sparse
            and e._typed_storage()._cdata in storages
        ):
            raise orig_not_implemented_exception

        if isinstance(e, torch.Tensor):
            if id(e) in inp_impls:
                return inp_impls[id(e)]
            else:
                return fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, e)
        else:
            return e

    return pytree.tree_map(map_out, r)
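
# Illustrative note (not part of the library): this fallback only runs when the op
# has no meta kernel and the mode allows fallback kernels (see
# maybe_run_unsafe_fallback above); fake inputs are replaced by zero-filled real
# tensors, the real kernel runs once, and the outputs are re-wrapped as FakeTensors.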


# Just for use to allow copying a module to fake tensors,
# does not apply elsewhere
class FakeCopyMode(TorchFunctionMode):
    def __init__(self, fake_mode):
        self.fake_mode = fake_mode

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        # clone will get called in Parameter deepcopy
        if func == torch._C.TensorBase.clone:
            return func(
                self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs
            )
        elif func == torch.Tensor.__deepcopy__:
            assert len(args) == 2 and len(kwargs) == 0
            tensor, memo = args

            if id(tensor) in memo:
                return memo[id(tensor)]

            out = self.fake_mode.from_tensor(tensor, static_shapes=True)
            memo[id(tensor)] = out
            return out
        else:
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
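
# Illustrative usage sketch (not part of the library; `real_module` stands in for
# any nn.Module):
#   import copy
#   mode = FakeTensorMode()
#   with FakeCopyMode(mode):
#       fake_module = copy.deepcopy(real_module)   # parameters/buffers become FakeTensors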


def _device_handler(args):
    # NB: Don't use is_our_fake, just serve the fake information
    # as is. Notice we don't use 'self'; we use args[0].fake_mode
    # because they may not be the same. It would also be possible
    # to return NotImplemented here, in which case the FakeTensor
    # handler on args[0] would handle it, but we're being nice and
    # short-circuiting quickly.
    assert len(args) == 1 and isinstance(args[0], FakeTensor)
    if args[0].fake_mode.in_kernel_invocation:
        return torch.device("meta")
    else:
        return args[0].fake_device


# [subclass inputs]
# Suppose we enable fake tensor mode. This means that fake tensor
# mode will run first. But what if we do an operation that
# involves a tensor subclass that will desugar into normal tensor
# operations? Without returning NotImplemented, fake tensor mode will run first,
# decide that a conversion was made (since there was a non fake
# tensor argument), and report an error that converting a non
# fake tensor is not supported. What we actually wanted to happen
# was to give the subclass a chance to figure out what it wants to do
# before erroring out. Returning NotImplemented here allows this.
def _check_for_subclass(flat_args):
    return any(_check_for_subclass_arg(x) for x in flat_args)


def _check_for_subclass_arg(x):
    return (
        not isinstance(x, FakeTensor)
        and isinstance(x, torch.Tensor)
        and type(x) is not torch.Tensor
        and type(x) is not torch.nn.Parameter
    )


_DISPATCH_META_HANDLERS = {
    torch.ops.prim.device.default: _device_handler,
    torch.ops.aten.size.default: lambda args: tuple(int(s) for s in args[0].size()),
    torch.ops.aten.stride.default: lambda args: tuple(int(s) for s in args[0].stride()),
    torch.ops.aten.storage_offset.default: lambda args: int(args[0].storage_offset()),
}
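
# Illustrative note (not part of the library): these handlers short-circuit pure
# metadata queries before the normal dispatch path, e.g. aten.size.default on a
# FakeTensor of shape (2, 3) returns the plain Python tuple (2, 3).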

_DISPATCH_HANDLE_DIRECTLY = ordered_set(
    torch.ops.aten.is_coalesced.default,
    torch.ops.aten.dense_dim.default,
    torch.ops.aten.sparse_dim.default,
)

from torch._subclasses.fake_impls import (  # noqa: F401
    _device_not_kwarg_ops,  # noqa: F401
    _is_tensor_constructor,  # noqa: F401
    _like_tensor_constructors,  # noqa: F401
    contains_tensor_types,  # noqa: F401
    get_fast_op_impls,
    has_meta,
    op_implementations_checks,
    stride_incorrect_op,
)