builder.py 101 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438
  1. # mypy: ignore-errors
  2. import abc
  3. import collections
  4. import contextlib
  5. import dataclasses
  6. import enum
  7. import functools
  8. import inspect
  9. import itertools
  10. import logging
  11. import math
  12. import operator
  13. import re
  14. import sys
  15. import types
  16. from typing import Any, List, NamedTuple, Optional, Union
  17. from torch.utils._sympy.value_ranges import ValueRanges
  18. try:
  19. import numpy as np
  20. except ModuleNotFoundError:
  21. np = None
  22. import torch
  23. from torch import SymInt
  24. from torch._guards import GuardSource, TracingContext
  25. from torch._higher_order_ops.torchbind import call_torchbind
  26. from torch._ops import HigherOrderOperator
  27. from torch._streambase import _EventBase, _StreamBase
  28. from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
  29. from torch._subclasses.meta_utils import is_sparse_any
  30. from torch.fx.experimental._backward_state import BackwardState
  31. from torch.fx.experimental.symbolic_shapes import (
  32. _constrain_range_for_size,
  33. DimDynamic,
  34. RelaxedUnspecConstraint,
  35. StatefulSymbolicContext,
  36. SubclassSymbolicContext,
  37. SymbolicContext,
  38. )
  39. from torch.fx.immutable_collections import immutable_dict, immutable_list
  40. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  41. from torch.utils.weak import TensorWeakRef
  42. from .. import config, mutation_guard, replay_record, trace_rules
  43. from ..device_interface import get_registered_device_interfaces
  44. from ..exc import InternalTorchDynamoError, unimplemented
  45. from ..guards import GuardBuilder, install_guard, make_dupe_guard
  46. from ..side_effects import SideEffects
  47. from ..source import (
  48. AttrSource,
  49. CallMethodItemSource,
  50. ConstantSource,
  51. ConstDictKeySource,
  52. ConvertIntSource,
  53. FloatTensorSource,
  54. GetItemSource,
  55. GradSource,
  56. is_cell_contents,
  57. is_constant_source,
  58. is_from_defaults,
  59. is_from_optimizer_source,
  60. LocalSource,
  61. NumpyTensorSource,
  62. OptimizerSource,
  63. RandomValueSource,
  64. Source,
  65. TupleIteratorGetItemSource,
  66. )
  67. from ..trace_rules import (
  68. is_callable_allowed,
  69. is_numpy,
  70. is_numpy_dtype,
  71. is_numpy_type_info,
  72. )
  73. from ..utils import (
  74. build_checkpoint_variable,
  75. clone_input,
  76. common_constant_types,
  77. get_fake_value,
  78. get_locals_to_steal,
  79. get_static_address_type,
  80. is_function_or_wrapper,
  81. is_namedtuple,
  82. is_typing,
  83. is_utils_checkpoint,
  84. istype,
  85. odict_values,
  86. proxy_args_kwargs,
  87. set_example_value,
  88. tensor_always_has_static_shape,
  89. tuple_iterator,
  90. tuple_iterator_getitem,
  91. tuple_iterator_len,
  92. unwrap_with_attr_name_if_wrapper,
  93. wrap_fake_exception,
  94. )
  95. from .base import MutableLocal, typestr, VariableTracker, VariableTrackerMeta
  96. from .constant import ConstantVariable, EnumVariable
  97. from .ctx_manager import (
  98. AutocastModeVariable,
  99. EventVariable,
  100. NullContextVariable,
  101. PreserveVersionContextVariable,
  102. StreamContextVariable,
  103. StreamVariable,
  104. )
  105. from .dicts import (
  106. ConstDictVariable,
  107. DataClassVariable,
  108. DefaultDictVariable,
  109. HFPretrainedConfigVariable,
  110. PythonSysModulesVariable,
  111. SetVariable,
  112. )
  113. from .distributed import (
  114. DeviceMeshVariable,
  115. PlacementClassVariable,
  116. PlacementVariable,
  117. ProcessGroupVariable,
  118. WorldMetaClassVariable,
  119. )
  120. from .functions import (
  121. CollectiveFunctionRewriteVariable,
  122. FunctoolsPartialVariable,
  123. TritonKernelVariable,
  124. UserMethodVariable,
  125. )
  126. from .higher_order_ops import TorchHigherOrderOperatorVariable
  127. from .iter import ItertoolsVariable
  128. from .lazy import LazyVariableTracker
  129. from .lists import (
  130. BaseListVariable,
  131. ListVariable,
  132. NamedTupleVariable,
  133. RangeVariable,
  134. RestrictedListSubclassVariable,
  135. SizeVariable,
  136. SliceVariable,
  137. TupleIteratorVariable,
  138. TupleVariable,
  139. )
  140. from .misc import (
  141. AutogradFunctionContextVariable,
  142. AutogradFunctionVariable,
  143. ComptimeVariable,
  144. DebuggingVariable,
  145. DelayGraphBreakVariable,
  146. GetAttrVariable,
  147. GetSetDescriptorVariable,
  148. InspectSignatureVariable,
  149. LambdaVariable,
  150. LoggingLoggerVariable,
  151. MethodWrapperVariable,
  152. NumpyDTypeVariable,
  153. NumpyTypeInfoVariable,
  154. NumpyVariable,
  155. PythonModuleVariable,
  156. RegexPatternVariable,
  157. SavedTensorBox,
  158. TorchVersionVariable,
  159. TypingVariable,
  160. )
  161. from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
  162. from .optimizer import OptimizerVariable
  163. from .script_object import TorchScriptObjectVariable
  164. from .sdpa import SDPAParamsVariable
  165. from .tensor import (
  166. NumpyNdarrayVariable,
  167. SymNodeVariable,
  168. TensorSubclassVariable,
  169. TensorVariable,
  170. UnspecializedPythonVariable,
  171. )
  172. from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
  173. from .torch_function import build_torch_function_fn, TensorWithTFOverrideVariable
  174. from .user_defined import (
  175. KeyedJaggedTensorVariable,
  176. SourcelessGraphModuleVariable,
  177. UserDefinedClassVariable,
  178. UserDefinedObjectVariable,
  179. )
  180. log = logging.getLogger(__name__)
  181. DimList = List
  182. class _missing:
  183. pass
  184. @dataclasses.dataclass
  185. class GraphArg:
  186. source: Source
  187. # TODO: storing a SymInt here but not a FakeTensor is a pretty strange
  188. # thing to do. Probably should have example (which stores an int) and
  189. # fake_example
  190. _example: Union[TensorWeakRef, torch.SymInt]
  191. # When True, this indicates that this GraphArg is a Python quantity (e.g.,
  192. # a float or int) which we pass to the FX graph as a Tensor. This
  193. # controls how we codegen calls into the Dynamo graph: we will call
  194. # torch.as_tensor on the quantity before passing it in.
  195. #
  196. # Note that we typically do not pass dynamic integers as tensors, because
  197. # they will most frequently just be used for size computation. But this
  198. # is a policy decision that we can change our mind on; in particular, when
  199. # an int comes from a random number generator (e.g., random.randint), we
  200. # DO pass it as a tensor.
  201. #
  202. # It's also worth noting that our current tracing rules for
  203. # pass_arg_as_tensor as subtly broken: we just pun the variable as a
  204. # 0d scalar Tensor and pray that the semantics are the same. Which they
  205. # often are, but not necessarily. ezyang(May 2024) plans to fix this
  206. # soon.
  207. pass_arg_as_tensor: bool
  208. fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
  209. # UnspecializedPythonVariable often masquerades as a tensor.
  210. # We MUST NOT generate shape guard code
  211. # that actually tries to access tensor properties on these values.
  212. # is_tensor lets us tell if this graph arg actually is a tensor
  213. # or not.
  214. is_tensor: bool = True
  215. # Sometimes, the Tensor we pass to example is freshly allocated (smh).
  216. # Then we cannot only keep a weak reference to it. This lets you
  217. # stash a strong reference too.
  218. example_strong_ref: Optional[torch.Tensor] = None
  219. @property
  220. def example(self):
  221. if isinstance(self._example, TensorWeakRef):
  222. r = self._example()
  223. assert r is not None
  224. return r
  225. else:
  226. return self._example
  227. def __post_init__(self):
  228. if isinstance(self._example, torch.Tensor):
  229. self._example = TensorWeakRef(self._example)
  230. assert is_fake(self.fake_tensor)
  231. def reconstruct(self, codegen):
  232. self.source.reconstruct(codegen)
  233. def erase(self):
  234. self._example = None
  235. self.example_strong_ref = None
  236. def __eq__(self, other):
  237. return self.source.name() == other.source.name()
  238. class BackwardStateGraphArg(GraphArg):
  239. def __init__(self):
  240. super().__init__(
  241. source=None,
  242. _example=BackwardState(),
  243. pass_arg_as_tensor=False,
  244. fake_tensor=None,
  245. is_tensor=False,
  246. )
  247. def reconstruct(self, codegen):
  248. assert codegen.tx.output.backward_state_var
  249. codegen.load_import_from(BackwardState.__module__, "BackwardState")
  250. codegen.call_function(0, True)
  251. codegen.dup_top()
  252. codegen.store(codegen.tx.output.backward_state_var)
  253. @dataclasses.dataclass
  254. class FrameStateSizeEntry:
  255. scalar: Optional[int]
  256. size: Optional[List[int]]
  257. class VariableBuilder:
  258. """Wrap a python value in a VariableTracker() instance"""
  259. def __init__(
  260. self,
  261. tx,
  262. source: Source,
  263. ):
  264. assert (
  265. source is not None
  266. ), "Consider SourcelessBuilder for ephemeral objects, usually objects created locally."
  267. assert TracingContext.try_get() is not None, "Expected active TracingContext"
  268. super().__init__()
  269. self.tx = tx
  270. self.source = source
  271. self.name = source.name()
  272. def __call__(self, value):
  273. if value in self.tx.output.side_effects:
  274. side_effect_result = self.tx.output.side_effects[value]
  275. dup_guard = make_dupe_guard(self.source, side_effect_result.source)
  276. if dup_guard:
  277. self.install_guards(dup_guard)
  278. return side_effect_result
  279. cached_vt = self.tx.output.variable_tracker_cache.lookup(value, self.source)
  280. if cached_vt:
  281. return cached_vt
  282. vt = self._wrap(value)
  283. vt.source = self.source
  284. if self._can_lift_attrs_to_inputs(vt):
  285. vt = self.tx.output.side_effects.track_object_existing(value, vt)
  286. self.tx.output.variable_tracker_cache.add(value, self.source, vt)
  287. return vt
  288. def _can_lift_attrs_to_inputs(self, vt):
  289. if type(vt) in [
  290. TensorVariable,
  291. TensorWithTFOverrideVariable,
  292. UserDefinedObjectVariable,
  293. NumpyNdarrayVariable,
  294. ]:
  295. return True
  296. return False
  297. @staticmethod
  298. @functools.lru_cache(None)
  299. def _common_constants():
  300. return {
  301. # We zero-one specialize shapes, so specialize these constants
  302. # too
  303. 0,
  304. 1,
  305. # NB: There used to be more constants here, but honestly it was
  306. # pretty confusing. Note we specialize floats by default, and
  307. # DON'T specialize ints by default. This all only matters with
  308. # dynamic_shapes
  309. }
  310. def get_source(self):
  311. return self.source
  312. def install_guards(self, *guards):
  313. source = self.get_source()
  314. if (
  315. isinstance(source, ConstantSource)
  316. or source.guard_source() == GuardSource.CONSTANT
  317. ):
  318. return None
  319. install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
  320. return {}
  321. def set_source_and_track_mutable(self, value, var):
  322. assert isinstance(var, VariableTracker)
  323. var.source = self.source
  324. return self.tx.output.side_effects.track_mutable(value, var)
  325. @classmethod
  326. @functools.lru_cache(None)
  327. def _type_dispatch(cls):
  328. # NB: Careful not to close over self to avoid ref cycle from lru_cache
  329. entries = [
  330. (
  331. (
  332. torch.Tensor,
  333. torch.nn.Parameter,
  334. torch._subclasses.FakeTensor,
  335. torch._subclasses.functional_tensor.FunctionalTensor,
  336. ),
  337. cls.wrap_tensor,
  338. ),
  339. (
  340. (tuple, list, odict_values, collections.deque, torch.Size),
  341. cls.wrap_listlike,
  342. ),
  343. (tuple_iterator, cls.wrap_tuple_iterator),
  344. ((slice, range), cls.wrap_slice_range),
  345. (tuple(common_constant_types), cls.wrap_literal),
  346. (re.Pattern, cls.wrap_regex_pattern),
  347. ]
  348. if config.trace_numpy and np:
  349. entries.append((np.ndarray, cls.wrap_numpy_ndarray))
  350. result = {}
  351. for ts, fn in entries:
  352. for t in ts if isinstance(ts, tuple) else (ts,):
  353. assert t not in result
  354. result[t] = fn
  355. return result
  356. def wrap_regex_pattern(self, value: re.Pattern):
  357. # TODO(jansel): something like a REPR_MATCH might be more robust here
  358. self.install_guards(GuardBuilder.ID_MATCH)
  359. return RegexPatternVariable(value)
  360. @classmethod
  361. @functools.lru_cache(None)
  362. def _id_dispatch(cls):
  363. from ..comptime import comptime
  364. entries = [
  365. (
  366. inspect.signature,
  367. lambda self, value: LambdaVariable(
  368. InspectSignatureVariable.create,
  369. source=self.source,
  370. **self.install_guards(GuardBuilder.CLOSURE_MATCH),
  371. ),
  372. ),
  373. (comptime, lambda self, value: ComptimeVariable()),
  374. (
  375. dataclasses.fields,
  376. lambda self, value: LambdaVariable(
  377. _dataclasses_fields_lambda,
  378. source=self.source,
  379. **self.install_guards(GuardBuilder.FUNCTION_MATCH),
  380. ),
  381. ),
  382. (torch.__version__, lambda self, value: TorchVersionVariable()),
  383. ]
  384. result = {}
  385. for ts, fn in entries:
  386. for t in ts if isinstance(ts, (tuple, list)) else (ts,):
  387. assert t not in result
  388. result[id(t)] = fn
  389. return result
  390. def _wrap(self, value):
  391. # import here to avoid circular dependencies
  392. from torch.utils._triton import has_triton
  393. if has_triton():
  394. from triton.runtime.autotuner import Autotuner
  395. from triton.runtime.jit import JITFunction
  396. else:
  397. class JITFunction:
  398. pass
  399. class Autotuner:
  400. pass
  401. # Handle exact type() match
  402. type_dispatch = self._type_dispatch().get(type(value))
  403. if type_dispatch is not None:
  404. return type_dispatch(self, value)
  405. # Handle exact id() match
  406. id_dispatch = self._id_dispatch().get(id(value))
  407. if id_dispatch is not None:
  408. return id_dispatch(self, value)
  409. # Note - There are some nested values where types mismatch!
  410. # We want to get those out and wrap those.
  411. value = inspect.getattr_static(value, "_torchdynamo_inline", value)
  412. # Everything else (NB: order matters!)
  413. if is_traceable_wrapper_subclass(value) or istype(
  414. value, config.traceable_tensor_subclasses
  415. ):
  416. return self.wrap_tensor(value)
  417. elif is_namedtuple(value):
  418. return self.wrap_listlike(value)
  419. elif value is torch.utils._pytree.SUPPORTED_NODES:
  420. # For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
  421. # under the assumption that the values themselves don't change.
  422. self.install_guards(GuardBuilder.DICT_VERSION)
  423. # The keys on the SUPPORTED_NODES can be arbitrary, so save on the
  424. # key order.
  425. self.tx.output.guard_on_key_order.add(self.source.name())
  426. result = {
  427. ConstantVariable.create(k): UserDefinedObjectVariable(
  428. v,
  429. source=GetItemSource(
  430. self.get_source(), ConstDictKeySource(self.get_source(), i)
  431. ),
  432. )
  433. for i, (k, v) in enumerate(value.items())
  434. }
  435. return ConstDictVariable(result, type(value))
  436. elif value is sys.modules:
  437. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  438. return PythonSysModulesVariable(source=self.source)
  439. elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
  440. if not value and self.get_source().is_nn_module():
  441. # It is faster to guard on 'false' property than to guard
  442. # on actual dict keys, but we can't do this fast guard in general because
  443. # it omits a crucial type check that ensures the value is actually still a dict at runtime.
  444. # Why is this OK for (specialized) nnmodules? We set up a setattr hook
  445. # to check for module property mutations, which does a reasonable,
  446. # but not completely secure job ensuring a property wasn't changed.
  447. self.install_guards(GuardBuilder.BOOL_FALSE)
  448. else:
  449. self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
  450. # Optimisation for the common case strings, ints, etc
  451. all_const = all(ConstantVariable.is_literal(k) for k in value.keys())
  452. if all_const:
  453. # TODO(anijain2305) - Do we have to guard on all the keys? Can
  454. # keys be guarded lazily, similar to values?
  455. self.install_guards(GuardBuilder.DICT_CONST_KEYS)
  456. else:
  457. # Guard on the key order
  458. # This is not ideal, i.e., there is no need to guard on the key
  459. # order. But we guard on the key order because of the complexity
  460. #
  461. # 1) For non-constant objects, we can't save the key in the
  462. # guard context because it can be memory heavy. We can add
  463. # weakrefs but this complicates the accesses.
  464. #
  465. # 2) For non-constant objects, we also have to guard on the keys
  466. # (like TENSOR_MATCH on tensor). We might also have guards on
  467. # the attributes of the keys (like tensor.grad). To make this
  468. # work in tree strucutre is complicated.
  469. #
  470. # So, instead we guard on the key order. While guarding on key
  471. # order, we just save the indices and use it to access keys and
  472. # values. Indices are cheap to save.
  473. self.tx.output.guard_on_key_order.add(self.source.name())
  474. # We need all the keys to be hashable. We do this within the
  475. # _HashableTracker class in dicts.py
  476. def build_key_value(i, k, v):
  477. if all_const:
  478. key = ConstantVariable.create(k)
  479. source_key = k
  480. else:
  481. source_key = ConstDictKeySource(self.get_source(), i)
  482. key = LazyVariableTracker.create(k, source_key)
  483. source_value = GetItemSource(self.get_source(), source_key)
  484. value = LazyVariableTracker.create(v, source_value)
  485. return key, value
  486. result = dict(
  487. build_key_value(i, k, v) for i, (k, v) in enumerate(value.items())
  488. )
  489. if istype(value, collections.defaultdict):
  490. factory_source = AttrSource(self.source, "default_factory")
  491. result = DefaultDictVariable(
  492. result,
  493. type(value),
  494. default_factory=VariableBuilder(self.tx, factory_source)(
  495. value.default_factory
  496. ),
  497. source=self.source,
  498. )
  499. else:
  500. result = ConstDictVariable(result, type(value), source=self.source)
  501. return self.set_source_and_track_mutable(value, result)
  502. elif isinstance(value, torch.nn.Module):
  503. return self.wrap_module(value)
  504. elif ConstantVariable.is_literal(value): # non-atomic literals
  505. return self.wrap_literal(value)
  506. elif istype(value, frozenset) and (
  507. ConstantVariable.is_literal(x) for x in value
  508. ):
  509. # For frozenset, we can guard by object ID instead of value
  510. # equality, this allows us to handle non-literal values
  511. self.install_guards(GuardBuilder.ID_MATCH)
  512. return ConstantVariable.create(value=value, source=self.source)
  513. elif isinstance(value, enum.Enum):
  514. self.install_guards(GuardBuilder.ID_MATCH)
  515. return EnumVariable(value=value, source=self.source)
  516. elif DebuggingVariable.is_reorderable_logging_function(value):
  517. # Put this above builtin_callable so that print() can be handled
  518. # along with other builtin debugging functions
  519. self.install_guards(GuardBuilder.BUILTIN_MATCH)
  520. return DebuggingVariable(value, source=self.source)
  521. elif isinstance(value, logging.Logger):
  522. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  523. return LoggingLoggerVariable(value, source=self.source)
  524. elif is_utils_checkpoint(value):
  525. return build_checkpoint_variable(source=self.source)
  526. elif isinstance(value, functools.partial):
  527. func_src = AttrSource(self.get_source(), "func")
  528. func_obj = VariableBuilder(self.tx, func_src)(value.func)
  529. args = []
  530. args_source = AttrSource(self.get_source(), "args")
  531. for i, arg in enumerate(value.args):
  532. args.append(
  533. VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
  534. )
  535. keywords = {}
  536. keywords_source = AttrSource(self.get_source(), "keywords")
  537. for k, v in value.keywords.items():
  538. if not ConstantVariable.is_literal(k):
  539. unimplemented("functools.partial with non-literal keyword")
  540. keywords[k] = VariableBuilder(
  541. self.tx, GetItemSource(keywords_source, k)
  542. )(v)
  543. install_guard(
  544. self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
  545. keywords_source.make_guard(GuardBuilder.DICT_KEYS),
  546. args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
  547. )
  548. return FunctoolsPartialVariable(func_obj, args, keywords)
  549. elif is_typing(value):
  550. # typing.List, typing.Mapping, etc.
  551. self.install_guards(GuardBuilder.ID_MATCH)
  552. return TypingVariable(
  553. value,
  554. source=self.source,
  555. )
  556. elif np is not None and isinstance(value, np.generic):
  557. # numpy array scalars: convert to 0D arrays
  558. return self.wrap_numpy_ndarray(np.asarray(value))
  559. elif is_numpy(value):
  560. assert np
  561. self.install_guards(
  562. GuardBuilder.FUNCTION_MATCH
  563. if callable(value)
  564. else GuardBuilder.TYPE_MATCH
  565. )
  566. return NumpyVariable(value, source=self.source)
  567. elif is_numpy_dtype(value):
  568. self.install_guards(GuardBuilder.ID_MATCH)
  569. return NumpyDTypeVariable(value, source=self.source)
  570. elif is_numpy_type_info(value):
  571. if isinstance(value, np.iinfo):
  572. self.install_guards(GuardBuilder.TYPE_MATCH)
  573. dt_source = AttrSource(self.source, "dtype")
  574. install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH))
  575. else:
  576. self.install_guards(GuardBuilder.ID_MATCH)
  577. return NumpyTypeInfoVariable(value, source=self.source)
  578. # NB: These can't be put in type_dispatch, they have to run later
  579. elif CollectiveFunctionRewriteVariable.can_rewrite(value):
  580. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  581. return CollectiveFunctionRewriteVariable.create(
  582. self.tx,
  583. value,
  584. source=self.source,
  585. )
  586. elif istype(value, torch.autograd.function.FunctionMeta):
  587. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  588. return AutogradFunctionVariable(
  589. value,
  590. source=self.source,
  591. )
  592. elif isinstance(value, torch.autograd.function.FunctionCtx):
  593. actual_saved_tensors = None
  594. try:
  595. actual_saved_tensors = value.saved_tensors
  596. except RuntimeError:
  597. pass
  598. saved_tensors = []
  599. guards = [self.source.make_guard(GuardBuilder.TYPE_MATCH)]
  600. if isinstance(actual_saved_tensors, tuple):
  601. saved_tensors_source = AttrSource(self.source, "saved_tensors")
  602. guards.append(
  603. saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH)
  604. )
  605. for i, v in enumerate(actual_saved_tensors):
  606. saved_tensors.append(
  607. VariableBuilder(
  608. self.tx, GetItemSource(saved_tensors_source, i)
  609. )(v)
  610. )
  611. install_guard(*guards)
  612. return self.tx.output.side_effects.track_object_existing(
  613. value,
  614. AutogradFunctionContextVariable(
  615. value,
  616. source=self.source,
  617. saved_tensors=SavedTensorBox(saved_tensors),
  618. ),
  619. )
  620. elif (
  621. isinstance(value, types.MethodType)
  622. and istype(
  623. getattr(value, "__self__", None), torch.autograd.function.FunctionMeta
  624. )
  625. and getattr(value, "__name__", "") == "apply"
  626. and value == getattr(value.__self__, "apply", None)
  627. ):
  628. # handle aliased autograd function `apply` calls
  629. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  630. return GetAttrVariable(
  631. AutogradFunctionVariable(
  632. value.__self__, source=AttrSource(self.source, member="__self__")
  633. ),
  634. "apply",
  635. )
  636. elif callable(value) and trace_rules.lookup_callable(value) is not None:
  637. if is_callable_allowed(value):
  638. self.tx.output.has_user_defined_allowed_in_graph = True
  639. return trace_rules.lookup_callable(value).create_with_source(
  640. value, source=self.source
  641. )
  642. elif np and isinstance(value, np.number):
  643. return self.wrap_unspecialized_primitive(value)
  644. elif DataClassVariable.is_matching_object(value):
  645. self.install_guards(GuardBuilder.TYPE_MATCH)
  646. return DataClassVariable.wrap(self, value)
  647. elif HFPretrainedConfigVariable.is_matching_object(value):
  648. self.install_guards(GuardBuilder.TYPE_MATCH)
  649. return HFPretrainedConfigVariable(value)
  650. elif isinstance(value, HigherOrderOperator):
  651. self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
  652. return TorchHigherOrderOperatorVariable.make(value, source=self.source)
  653. elif isinstance(value, torch.cuda.StreamContext):
  654. self.install_guards(GuardBuilder.ID_MATCH)
  655. stream_source = AttrSource(self.source, "stream")
  656. stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
  657. return StreamContextVariable.create(self.tx, stream_var)
  658. elif isinstance(value, _StreamBase):
  659. self.install_guards(GuardBuilder.ID_MATCH)
  660. stream_proxy = self.tx.output.create_proxy(
  661. "call_function",
  662. torch.cuda.Stream,
  663. (),
  664. {
  665. "stream_id": value.stream_id,
  666. "device_index": value.device_index,
  667. "device_type": value.device_type,
  668. },
  669. )
  670. set_example_value(stream_proxy.node, value)
  671. return StreamVariable(
  672. stream_proxy,
  673. value,
  674. value.device,
  675. source=self.source,
  676. )
  677. elif isinstance(value, (torch._C._SDPAParams)):
  678. self.install_guards(GuardBuilder.TYPE_MATCH)
  679. return SDPAParamsVariable.create(self.tx, value, self.source)
  680. elif isinstance(value, _EventBase):
  681. self.install_guards(GuardBuilder.ID_MATCH)
  682. return EventVariable(
  683. None,
  684. value,
  685. source=self.source,
  686. )
  687. elif (
  688. isinstance(value, torch._C._TensorMeta)
  689. and value in config.traceable_tensor_subclasses
  690. ):
  691. return TensorSubclassVariable(value, source=self.source)
  692. elif (
  693. istype(value, contextlib.nullcontext)
  694. and inspect.getattr_static(value, "enter_result", None) is None
  695. ):
  696. self.install_guards(GuardBuilder.TYPE_MATCH)
  697. return NullContextVariable(source=self.source)
  698. elif KeyedJaggedTensorVariable.is_matching_object(value):
  699. self.install_guards(GuardBuilder.TYPE_MATCH)
  700. result = KeyedJaggedTensorVariable(value, source=self.source)
  701. # TODO: this doing it manually is bad
  702. return self.tx.output.side_effects.track_object_existing(value, result)
  703. elif isinstance(value, torch.optim.Optimizer):
  704. self.install_guards(GuardBuilder.ID_MATCH)
  705. self.source = OptimizerSource(self.source)
  706. return OptimizerVariable(value, source=self.source)
  707. elif WorldMetaClassVariable.is_group_member_type(value):
  708. return WorldMetaClassVariable(value, source=self.source)
  709. elif ProcessGroupVariable.is_process_group(value):
  710. self.install_guards(GuardBuilder.ID_MATCH)
  711. return ProcessGroupVariable(value, source=self.source)
  712. elif DeviceMeshVariable.is_device_mesh(value):
  713. # TODO: see if we need to add custom guard instead of a simple ID_MATCH
  714. self.install_guards(GuardBuilder.ID_MATCH)
  715. return DeviceMeshVariable(value, source=self.source)
  716. elif PlacementClassVariable.is_placement_type(value):
  717. # TODO: see if we need to add custom guard instead of a simple ID_MATCH
  718. self.install_guards(GuardBuilder.ID_MATCH)
  719. return PlacementClassVariable(value, source=self.source)
  720. elif PlacementVariable.is_placement(value):
  721. # TODO: see if we need to add custom guard instead of a simple ID_MATCH
  722. self.install_guards(GuardBuilder.ID_MATCH)
  723. return PlacementVariable(
  724. value,
  725. source=self.source,
  726. )
  727. elif istype(value, type) and value in itertools.__dict__.values():
  728. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  729. return ItertoolsVariable(value, source=self.source)
  730. elif isinstance(value, torch.SymBool):
  731. # Note: the idea here is to re-use the infra we've built for SymInt by simulating the
  732. # user provided SymBool with a SymInt in dynamo.
  733. # Concretely,
  734. # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source).
  735. # so that guards on the SymInts can be effectively applied on the original SymBool in user program.
  736. # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program
  737. # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly.
  738. value_hint = value.node.require_hint()
  739. new_source = ConvertIntSource(self.source)
  740. new_symint = self.tx.output.shape_env.create_unspecified_symint_and_symbol(
  741. int(value_hint),
  742. new_source,
  743. dynamic_dim=DimDynamic.DYNAMIC,
  744. )
  745. sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
  746. re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
  747. type(new_symint),
  748. source=new_source,
  749. )
  750. sym_node_proxy.node.meta["grapharg"] = GraphArg(
  751. new_source,
  752. new_symint,
  753. False,
  754. None,
  755. is_tensor=False,
  756. example_strong_ref=new_symint,
  757. )
  758. self.tx.output.bound_symbols.add(new_symint.node.expr)
  759. self.tx.output.tracked_fakes.append(
  760. TrackedFake(new_symint, new_source, None)
  761. )
  762. return SymNodeVariable(
  763. sym_node_proxy,
  764. new_symint == 1,
  765. )
  766. elif isinstance(value, (JITFunction, Autotuner)):
  767. self.install_guards(GuardBuilder.ID_MATCH)
  768. return TritonKernelVariable(
  769. value,
  770. None, # No kernel idx provided
  771. None, # No grid provided
  772. source=self.source,
  773. )
  774. elif isinstance(value, torch.amp.autocast_mode.autocast):
  775. self.install_guards(GuardBuilder.ID_MATCH)
  776. return AutocastModeVariable(
  777. target_values=[
  778. value.device,
  779. value.fast_dtype,
  780. value._enabled,
  781. value._cache_enabled,
  782. ],
  783. source=self.source,
  784. )
  785. elif TorchCtxManagerClassVariable.is_matching_cls(value):
  786. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  787. return TorchCtxManagerClassVariable(value, source=self.source)
  788. elif is_function_or_wrapper(value):
  789. value, attr_name = unwrap_with_attr_name_if_wrapper(value)
  790. # For these wrappers, Dynamo points to the wrapped function,
  791. # so source needs to be updated as well.
  792. if attr_name is not None:
  793. self.source = AttrSource(self.source, attr_name)
  794. return trace_rules.lookup(value).create_with_source(
  795. value, source=self.source
  796. )
  797. # Don't use istype, since some python modules are not subclasses of types.ModuleType directly.
  798. # E.g, type(torch.ops) -> <class 'torch._ops._Ops'>,
  799. # type(torch.backends.cudnn) -> <class 'torch.backends.cudnn.CudnnModule'>
  800. elif isinstance(value, (types.ModuleType, replay_record.DummyModule)):
  801. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  802. return PythonModuleVariable(
  803. value,
  804. source=self.source,
  805. )
  806. elif isinstance(value, types.MethodType) and isinstance(
  807. value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
  808. ):
  809. # don't let MethodTypes fall through to UserDefinedObject,
  810. # which doesn't support 'CALL_FUNCTION'
  811. # TODO(whc): Why do we limit this to methods on NNModules?
  812. # I don't have a good reason for this, but it preserves the existing behavior
  813. # for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise.
  814. # I suspect we probably want to relax this check and dig deeper there.
  815. # In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python,
  816. # but need to separately wrap its underlying `__func__` and its `self` argument. We wrap `self` here
  817. # and then `__func__` gets wrapped inside UserMethodVariable.
  818. self_obj = VariableBuilder(
  819. self.tx, source=AttrSource(self.source, "__self__")
  820. )(value.__self__)
  821. assert self_obj and isinstance(
  822. self_obj, VariableTracker
  823. ), "Failed to produce a valid self obj"
  824. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  825. return UserMethodVariable(
  826. value.__func__,
  827. self_obj,
  828. source=self.source,
  829. )
  830. elif isinstance(value, types.GetSetDescriptorType):
  831. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  832. return GetSetDescriptorVariable(value)
  833. elif isinstance(value, types.MethodWrapperType):
  834. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  835. return MethodWrapperVariable(value)
  836. elif issubclass(type(value), type):
  837. if value in (torch.utils.hooks.BackwardHook, torch.nn.Parameter):
  838. # TODO(jansel): combine this case with the one above
  839. return trace_rules.lookup(value).create_with_source(
  840. value, source=self.source
  841. )
  842. if value is torch.autograd._unsafe_preserve_version_counter:
  843. self.install_guards(GuardBuilder.FUNCTION_MATCH)
  844. return PreserveVersionContextVariable.constructor(self.tx)
  845. # This is a userdefined class, so install an ID_MATCH even if its a
  846. # global variable.
  847. self.install_guards(GuardBuilder.ID_MATCH)
  848. return UserDefinedClassVariable(
  849. value,
  850. source=self.source,
  851. )
  852. elif RestrictedListSubclassVariable.is_matching_cls(type(value)):
  853. self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
  854. return self.set_source_and_track_mutable(
  855. value,
  856. RestrictedListSubclassVariable(
  857. [
  858. LazyVariableTracker.create(
  859. value=value[i], source=GetItemSource(self.source, i)
  860. )
  861. for i in range(len(value))
  862. ],
  863. user_cls=type(value),
  864. user_cls_source=AttrSource(self.source, "__class__"),
  865. ),
  866. )
  867. elif TorchScriptObjectVariable.is_matching_cls(type(value)):
  868. from ..source import (
  869. FlattenScriptObjectSource,
  870. ScriptObjectQualifiedNameSource,
  871. )
  872. # This exists to allow a smoother transition.
  873. # The implications are:
  874. # The script objects won't be tracked as proxies.
  875. # Methods on these objects won't show up in the graph.
  876. # The original script object might be mutated.
  877. if not hasattr(value, "__obj_flatten__"):
  878. return self.wrap_user_defined(value)
  879. # Install the guards on the fully qualified name of the script object
  880. LazyVariableTracker.realize_all(
  881. VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
  882. value._type().qualified_name() # type: ignore[attr-defined]
  883. )
  884. )
  885. # Install the guards on the content of the script object by setting the source
  886. # to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
  887. LazyVariableTracker.realize_all(
  888. VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
  889. value.__obj_flatten__()
  890. )
  891. )
  892. fake_script_obj = torch._library.fake_class_registry.to_fake_obj(
  893. self.tx.output.fake_mode, value
  894. )
  895. proxy = self.tx.output.root_tracer.create_graph_input(
  896. re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
  897. type(value),
  898. source=self.source,
  899. )
  900. # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default
  901. # seting example to be real value because these example values will be used
  902. # as example_inputs for user compiler.
  903. proxy.node.meta["grapharg"] = GraphArg(
  904. self.source, value, False, None, False, fake_script_obj
  905. )
  906. return TorchScriptObjectVariable.create(
  907. proxy,
  908. fake_script_obj,
  909. source=self.source,
  910. )
  911. else:
  912. return self.wrap_user_defined(value)
  913. def wrap_user_defined(self, value: Any):
  914. self.install_guards(GuardBuilder.TYPE_MATCH)
  915. result = UserDefinedObjectVariable(value, source=self.source)
  916. if not SideEffects.cls_supports_mutation_side_effects(type(value)):
  917. # don't allow STORE_ATTR mutation with custom __setattr__
  918. return result
  919. return self.tx.output.side_effects.track_object_existing(value, result)
  920. def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
  921. if config.specialize_int and type(value) is torch.Size:
  922. self.install_guards(GuardBuilder.CONSTANT_MATCH)
  923. return ConstantVariable.create(value=value)
  924. # One can index a tensor with a list/tuple. Therefore, we need to
  925. # have a stricter match.
  926. self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
  927. for item in value:
  928. if item is value:
  929. unimplemented("list elements are pointing to the list itself")
  930. output = [
  931. LazyVariableTracker.create(item, source=GetItemSource(self.get_source(), i))
  932. for i, item in enumerate(value)
  933. ]
  934. maybe_gm = self.tx.output.local_scope.get("self")
  935. if isinstance(
  936. self.source, LocalSource
  937. ) and self.source.local_name in get_locals_to_steal(maybe_gm):
  938. # The input tensor list to dynamo from compiled autograd may contain activations
  939. # which are freed as they are used in inductor. Dynamo's default behavior is to
  940. # lift all tensors to the graph inputs, but this will cause dynamo to hold an
  941. # extra reference to the activation tensors and increase peak memory usage.
  942. # To allow freeing ASAP, we keep the list as graph argument to the dynamo output
  943. # graph, and unpack it locally.
  944. # e.g. instead of `def forward(self, L_inputs_0_, L_inputs_1_, ...):`, we have
  945. # `def forward(self, L_inputs_):`
  946. source = self.source
  947. assert isinstance(value, list)
  948. tensor_list_proxy = self.tx.output.root_tracer.create_graph_input(
  949. re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source
  950. )
  951. tensor_list_proxy.node.meta["steal_arg"] = True
  952. list_variable = wrap_fx_proxy_cls(
  953. target_cls=TensorVariable,
  954. tx=self.tx,
  955. proxy=tensor_list_proxy,
  956. example_value=value,
  957. subclass_type=None,
  958. source=source,
  959. )
  960. guards = []
  961. for i, tensor_variable in enumerate(list_variable.items):
  962. source_i = GetItemSource(base=source, index=i, index_is_slice=False)
  963. # access unpacked tensor from this list instead of from a lifted arg
  964. self.tx.output.input_source_to_var[source_i] = tensor_variable
  965. guard = functools.partial(
  966. GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
  967. )
  968. guards.append(source_i.make_guard(guard))
  969. install_guard(*guards, skip=1)
  970. grapharg = GraphArg(
  971. source,
  972. value,
  973. pass_arg_as_tensor=False,
  974. fake_tensor=None,
  975. is_tensor=False,
  976. )
  977. tensor_list_proxy.node.meta["grapharg"] = grapharg
  978. result = BaseListVariable.cls_for_instance(value)(
  979. output, mutable_local=MutableLocal()
  980. )
  981. if istype(value, list):
  982. return self.set_source_and_track_mutable(value, result)
  983. return result
  984. def wrap_tuple_iterator(self, value: tuple_iterator):
  985. self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
  986. output = [
  987. VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))(
  988. tuple_iterator_getitem(value, i)
  989. )
  990. for i in range(tuple_iterator_len(value))
  991. ]
  992. result = TupleIteratorVariable(
  993. output, mutable_local=MutableLocal(), source=self.source
  994. )
  995. return self.set_source_and_track_mutable(value, result)
  996. def wrap_slice_range(self, value: Union[slice, range]):
  997. items = [
  998. VariableBuilder(self.tx, AttrSource(self.get_source(), k))(
  999. getattr(value, k)
  1000. )
  1001. for k in ("start", "stop", "step")
  1002. ]
  1003. self.install_guards(GuardBuilder.TYPE_MATCH)
  1004. if isinstance(value, slice):
  1005. return SliceVariable(items, source=self.source)
  1006. else:
  1007. return RangeVariable(items, source=self.source)
  1008. def wrap_module(self, value: torch.nn.Module):
  1009. from ..eval_frame import OptimizedModule
  1010. if len(value.__dict__) == 0:
  1011. unimplemented(f"uninitialized nn.Module: {typestr(value)}")
  1012. if istype(value, OptimizedModule):
  1013. # Check if the optimized module was disabled
  1014. if inspect.getattr_static(value.forward, "_torchdynamo_disable", False):
  1015. # This bytecode is mostly of kind LOAD_ATTR or LOAD_METHOD. If
  1016. # we graph break here, Dynamo does not know how to create
  1017. # continuation functions for such bytecodes. So, we delay the
  1018. # graph break to CALL_FUNCTION.
  1019. return DelayGraphBreakVariable(source=self.source)
  1020. self.install_guards(GuardBuilder.TYPE_MATCH)
  1021. self.source = AttrSource(self.source, "_orig_mod")
  1022. return self.wrap_module(value._orig_mod)
  1023. if (
  1024. isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
  1025. and not config.allow_rnn
  1026. ):
  1027. unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
  1028. if mutation_guard.is_dynamic_nn_module(value, self.tx.export):
  1029. # created dynamically, don't specialize on it
  1030. self.install_guards(GuardBuilder.TYPE_MATCH)
  1031. result = UnspecializedNNModuleVariable(value, source=self.source)
  1032. if not SideEffects.cls_supports_mutation_side_effects(type(value)):
  1033. # don't allow STORE_ATTR mutation with custom __setattr__
  1034. return result
  1035. return self.tx.output.side_effects.track_object_existing(value, result)
  1036. elif issubclass(
  1037. value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
  1038. ):
  1039. self.install_guards(GuardBuilder.TYPE_MATCH)
  1040. return UnspecializedNNModuleVariable(value)
  1041. elif getattr(value, "_is_fsdp_managed_module", False):
  1042. # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
  1043. # in fully_sharded_data_parallel.py for more information
  1044. # we can't do this assert inside FSDP constructor,
  1045. # since we don't know yet whether dynamo will be used
  1046. assert getattr(
  1047. value, "_fsdp_use_orig_params", False
  1048. ), "Dynamo only supports FSDP with use_orig_params=True"
  1049. # Note on FSDP guarding
  1050. # 1. We expect FSDP wrapping mutates an nn module irreversably (no way to de-wrap).
  1051. # 2. Eager FSDP already assumes (requires, but without enforcement) that users don't mutate their
  1052. # model parameters/structure after FSDP wrapping, because FSDP wouldn't notice or update its FlatParams.
  1053. #
  1054. # Due to (1), once we enter this path we expect not to go back nor have to guard on type
  1055. # or _is_fsdp_managed_module.
  1056. #
  1057. # TODO(whc) We could add a guard on the opposite case, where a user compiled/ran
  1058. # pre-FSDP-wrapped model, then wrapped, to ensure that we recompile with the FSDP handling.
  1059. #
  1060. # Due to (2), we skip guards on inner contents of fsdp_managed modules, by using FSDPNNModuleSource as the
  1061. # guard source. This behavior is gated on config.skip_fsdp_guards.
  1062. #
  1063. # ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps
  1064. # them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager)
  1065. self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH)
  1066. return FSDPManagedNNModuleVariable(value, source=self.get_source())
  1067. else:
  1068. return self.tx.output.register_attr_or_module(
  1069. value,
  1070. self.name,
  1071. source=self.get_source(),
  1072. # Guards are added inside register_attr_or_module
  1073. )
  1074. def wrap_literal(self, value):
  1075. if not config.specialize_int and type(value) is int:
  1076. # unspecializing int by default, but still
  1077. # specialize for the following conditions
  1078. if not TracingContext.get().force_unspec_int_unbacked_size_like and (
  1079. value in self._common_constants()
  1080. # Assume integers from global variables want to be specialized
  1081. or not self.source.guard_source().is_local()
  1082. or is_from_defaults(self.source)
  1083. or is_cell_contents(self.source)
  1084. ):
  1085. self.install_guards(GuardBuilder.CONSTANT_MATCH)
  1086. return ConstantVariable.create(value=value, source=self.source)
  1087. else:
  1088. return self.wrap_symint(value)
  1089. elif not config.specialize_float and type(value) is float:
  1090. return self.wrap_symfloat(value)
  1091. else:
  1092. self.install_guards(GuardBuilder.CONSTANT_MATCH)
  1093. return ConstantVariable.create(value=value)
  1094. def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
  1095. if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
  1096. raise InternalTorchDynamoError(
  1097. "Cannot wrap a Tensor that has already been",
  1098. "wrapped by this instance of Dynamo",
  1099. )
def wrap_tensor(self, value: torch.Tensor):
    """Wrap a tensor read from ``self.source`` into a VariableTracker.

    Paths, in order:

    * nn-module sources, or tensors with a static address type (unless the
      source is an FSDP module) -> ``register_attr_or_module``;
    * constant sources -> ``register_attr_or_module`` under a sanitized name;
    * a source already in ``input_source_to_var`` -> that cached variable;
    * otherwise a new graph input is created on the root tracer. Guards
      installed on it:
      - TENSOR_MATCH (NOT_NONE_MATCH for optimizer ``GradSource``s);
      - TYPE_MATCH for traceable tensor subclasses and traceable wrapper
        subclasses, with the wrapper's inner attributes wrapped recursively.
      The result is cached in ``input_source_to_var`` and a ``GraphArg`` is
      attached to the proxy node.

    Graph-breaks (``unimplemented``) on strided NestedTensors and non-COO
    sparse layouts. Raises ``InternalTorchDynamoError`` if ``value`` is
    already a fake of this graph's fake mode, or if the wrapped example
    value is not this graph's fake.
    """
    source = self.get_source()

    # We cannot already be tracking the tensor, which implies
    # it would have already been wrapped
    assert value not in self.tx.output.side_effects

    if (
        source.guard_source().is_nn_module()
        or get_static_address_type(value) is not None
    ) and not source.guard_source().is_fsdp_module():
        self.assert_not_wrapped_by_this_graph(value)
        return self.tx.output.register_attr_or_module(
            value, self.name, source=source
        )

    if is_constant_source(source):
        self.assert_not_wrapped_by_this_graph(value)
        return self.tx.output.register_attr_or_module(
            value,
            re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
            source=source,
            # Guards are added inside register_attr_or_module
        )

    if type(value) in config.traceable_tensor_subclasses:
        # Ordinarily, we would fakeify a tensor so that it can get dynamic
        # shapes and be computed on without triggering actual operations.
        # However, how can we fakeify a tensor subclass?  Ordinary
        # inheritance (nor multiple inheritance) won't work.
        #
        # Instead, our plan is to *manually simulate* the tensor subclass
        # inheriting from a fake tensor with dynamo.  This means our
        # data representation for a tensor subclass will be a fake tensor
        # + tensor subclass type + any extra data the subclass may have
        # been storing on the tensor.  Because all Python accesses are
        # mediated through TensorWithTFOverrideVariable, we can ensure
        # that we dispatch differently, e.g., according to
        # __torch_function__
        #
        # To simplify things for now, the __dict__ tracking bits haven't
        # been implemented yet, but they can be added into this design at
        # a later point in time.
        subclass_type = type(value)
    else:
        # Only plain tensors, parameters, fake/functional tensors, or
        # traceable wrapper subclasses may reach this path.
        assert type(value) in (
            torch.Tensor,
            torch.nn.Parameter,
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ) or is_traceable_wrapper_subclass(value), type(value)
        subclass_type = None

    # NB: this just says we accessed a tensor from the same source again
    # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice).
    # This is distinct from two distinct sources mapping to the same
    # Tensor (per id())!  No guard is necessary here.  See below for the
    # other case.
    is_duplicate_tensor = source in self.tx.output.input_source_to_var
    if is_duplicate_tensor:
        return self.tx.output.input_source_to_var[source]

    # By this point, we should have deduplicated all tensors
    self.assert_not_wrapped_by_this_graph(value)

    # tx.output has multiple tracers if we're introspecting HigherOrderOperator.
    # When we've discovered an untracked tensor, then we actually need
    # to get Dynamo to track the tensor (which is what this function does)
    # and put it as a graph input on the root tracer. Later on,
    # if the input is actually used in the body of the HigherOrderOperator,
    # then the relevant SubgraphTracer will lift it to being an input of
    # the subgraph.
    # See NOTE [HigherOrderOperator tracing design] for more details.

    tensor_proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source
    )
    options = {}
    if type(value) in config.traceable_tensor_subclasses:
        # Traceable subclasses carry their __torch_function__ along and are
        # additionally guarded on their exact type.
        options["torch_function_fn"] = build_torch_function_fn(
            self.tx, value, self.source
        )
        self.install_guards(GuardBuilder.TYPE_MATCH)

    if (
        isinstance(value, torch.Tensor)
        and value.is_nested
        and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor)
    ):
        unimplemented("torch.compile does not support strided NestedTensor")

    # Reject sparse, but not coo.
    # TODO: remove this altogether when non-coo sparsity propagation is ready
    if is_sparse_any(value) and not value.is_sparse:
        unimplemented(
            f"torch.compile does not support sparse Tensor with {value.layout} layout"
        )

    tensor_variable = wrap_fx_proxy(
        tx=self.tx,
        proxy=tensor_proxy,
        example_value=value,
        subclass_type=subclass_type,
        source=source,
        **options,
    )

    # Gradients reached through an optimizer only get a not-None check
    # rather than a full TENSOR_MATCH.
    guard_type = GuardBuilder.TENSOR_MATCH

    if isinstance(source, GradSource) and is_from_optimizer_source(source):
        guard_type = GuardBuilder.NOT_NONE_MATCH

    # Numpy-derived sources pass the value itself; everything else passes a
    # weak reference to the tensor.
    self.install_guards(
        functools.partial(
            guard_type,
            value=value
            if isinstance(source, NumpyTensorSource)
            else TensorWeakRef(value),
        )
    )

    # We install TYPE_MATCH guards for traceable wrapper subclass object,
    # and recursively install corresponding guard for each inner attribute.
    if is_traceable_wrapper_subclass(value):
        self.install_guards(GuardBuilder.TYPE_MATCH)
        attrs, _ = value.__tensor_flatten__()
        for attr in attrs:
            inner_value = getattr(value, attr)
            inner_source = AttrSource(self.source, attr)
            LazyVariableTracker.realize_all(
                VariableBuilder(self.tx, inner_source)(inner_value)
            )

    # Cache so a repeated access from the same source returns this variable.
    self.tx.output.input_source_to_var[source] = tensor_variable
    assert "tensor_dict" not in tensor_proxy.node.meta
    tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy()

    # Note: this information is conveyed via subclass_type now
    fake_tensor_value = tensor_variable.proxy.node.meta["example_value"]
    if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode:
        raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake")

    grapharg = GraphArg(source, value, False, fake_tensor_value)
    tensor_proxy.node.meta["grapharg"] = grapharg
    self.tx.output.add_symbol_bindings(grapharg)
    return tensor_variable
  1228. def wrap_numpy_ndarray(self, value):
  1229. assert np is not None
  1230. assert isinstance(value, np.ndarray)
  1231. source = NumpyTensorSource(self.get_source())
  1232. from torch._numpy import _util
  1233. readonly = not value.flags.writeable
  1234. if readonly:
  1235. try:
  1236. value.flags.writeable = True
  1237. except ValueError:
  1238. # One can not easily make nditer elements writable,
  1239. # but warning is not the end of the world
  1240. assert isinstance(value.base, np.nditer)
  1241. pass
  1242. try:
  1243. tensor_value = _util._try_convert_to_tensor(value)
  1244. if readonly:
  1245. from torch._prims_common import clone_preserve_strides
  1246. tensor_value = clone_preserve_strides(tensor_value)
  1247. except NotImplementedError as e:
  1248. # failed to convert to tensor, graph break
  1249. unimplemented(str(e))
  1250. # We do this because we want the full behavior of guarding the numpy ndarray as if it were
  1251. # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
  1252. # that there's not another great way to do this atm.
  1253. # This creates the right graphargs, as well as registration for guards in tensor names and shape env.
  1254. LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value))
  1255. proxy = self.tx.output.root_tracer.create_graph_input(
  1256. re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source
  1257. )
  1258. options = {"source": source}
  1259. numpy_ndarray_variable = wrap_fx_proxy_cls(
  1260. target_cls=NumpyNdarrayVariable,
  1261. tx=self.tx,
  1262. proxy=proxy,
  1263. example_value=tensor_value,
  1264. **options,
  1265. )
  1266. self.tx.output.input_source_to_var[source] = numpy_ndarray_variable
  1267. example_value = numpy_ndarray_variable.proxy.node.meta["example_value"]
  1268. # pass_arg_as_tensor should be true because we are wrapping a np.ndarray as argument input, and it needs to be
  1269. # converted to a tensor.
  1270. grapharg = GraphArg(
  1271. source,
  1272. tensor_value,
  1273. pass_arg_as_tensor=True,
  1274. fake_tensor=example_value,
  1275. is_tensor=True,
  1276. example_strong_ref=tensor_value,
  1277. )
  1278. proxy.node.meta["grapharg"] = grapharg
  1279. return numpy_ndarray_variable
  1280. def wrap_symint(self, value):
  1281. assert type(value) is int
  1282. if self.name in self.tx.output.unspec_variable_map:
  1283. return self.tx.output.unspec_variable_map[self.name]
  1284. shape_env = self.tx.output.shape_env
  1285. if TracingContext.get().force_unspec_int_unbacked_size_like:
  1286. wrapped_value = shape_env.create_unbacked_symint()
  1287. _constrain_range_for_size(wrapped_value)
  1288. self.tx.output.bound_symbols.add(wrapped_value.node.expr)
  1289. self.tx.output.tracked_fakes.append(
  1290. TrackedFake(wrapped_value, self.source, None)
  1291. )
  1292. # NB: We do not do float. For motivation, see
  1293. # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
  1294. # but the general idea is that we generate kernels that can
  1295. # take unspecialized floats and use them in sizevar computation
  1296. elif not is_constant_source(self.get_source()):
  1297. if torch._dynamo.config.specialize_int:
  1298. # If specialize_int is False, also return
  1299. # a constant (but this should have been handled
  1300. # in the caller, TBH)
  1301. self.install_guards(GuardBuilder.CONSTANT_MATCH)
  1302. return ConstantVariable.create(value=value, source=self.source)
  1303. name = self.source.name()
  1304. if name not in self.tx.output.frame_state:
  1305. # Note - this essentially means that if this name gets reused as a tensor,
  1306. # it will start fully dynamic. That should always be a safe option, and not awfully inefficient.
  1307. # Alternatively, if we want to improve pef here, we can add a third state of unset, but I am not
  1308. # sure that is necessary for now.
  1309. frame_state_entry = FrameStateSizeEntry(scalar=value, size=None)
  1310. else:
  1311. frame_state_entry = self.tx.output.frame_state[name]
  1312. if frame_state_entry.scalar != value:
  1313. log.debug(
  1314. "automatic dynamic int %s val %s != %s",
  1315. name,
  1316. value,
  1317. frame_state_entry.scalar,
  1318. )
  1319. frame_state_entry.scalar = None
  1320. self.tx.output.frame_state[name] = frame_state_entry
  1321. # TODO: This should be dynamic, as we in general do not
  1322. # know if bare integers are actually going to be sizevars
  1323. # and it is inappropriate to eagerly duck size them with
  1324. # real sizevars
  1325. if (
  1326. config.automatic_dynamic_shapes and frame_state_entry.scalar is None
  1327. ) or not config.assume_static_by_default:
  1328. dynamic_dim = DimDynamic.DYNAMIC
  1329. else: # assume_static_by_default
  1330. # TODO: dynamic_dim = DimDynamic.STATIC should work but
  1331. # for some reason it doesn't
  1332. self.install_guards(GuardBuilder.CONSTANT_MATCH)
  1333. return ConstantVariable.create(value=value)
  1334. wrapped_value = shape_env.create_unspecified_symint_and_symbol(
  1335. value,
  1336. source=self.source,
  1337. dynamic_dim=dynamic_dim,
  1338. )
  1339. self.tx.output.bound_symbols.add(wrapped_value.node.expr)
  1340. self.tx.output.tracked_fakes.append(
  1341. TrackedFake(wrapped_value, self.source, None)
  1342. )
  1343. else:
  1344. assert is_constant_source(self.get_source())
  1345. # TODO: Do I actually need guard for constant source?
  1346. self.install_guards(GuardBuilder.CONSTANT_MATCH)
  1347. return ConstantVariable.create(value=value, source=self.source)
  1348. assert not isinstance(self.get_source(), RandomValueSource)
  1349. install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
  1350. options = {"source": self.get_source()}
  1351. proxy = self.tx.output.root_tracer.create_graph_input(
  1352. re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
  1353. type(wrapped_value),
  1354. source=self.get_source(),
  1355. )
  1356. set_example_value(proxy.node, wrapped_value)
  1357. unspec_var = SymNodeVariable(proxy, wrapped_value, **options)
  1358. self.tx.output.unspec_variable_map[self.name] = unspec_var
  1359. if not is_constant_source(self.get_source()):
  1360. if self.tx.export and not isinstance(self.get_source(), LocalSource):
  1361. raise AssertionError(
  1362. f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
  1363. )
  1364. example_value = unspec_var.proxy.node.meta["example_value"]
  1365. proxy.node.meta["grapharg"] = GraphArg(
  1366. self.get_source(),
  1367. wrapped_value,
  1368. pass_arg_as_tensor=False,
  1369. fake_tensor=None,
  1370. is_tensor=False,
  1371. example_strong_ref=wrapped_value,
  1372. )
  1373. return unspec_var
  1374. def wrap_symfloat(self, value):
  1375. # SymFloat wrapping is special. We first wrap it in the same way we
  1376. # do an unspecialized primitive, and then we item() it into a
  1377. # SymFloat. Removal of the item() call is left to a later FX pass,
  1378. # mostly because that pass is more easily done after we have lowered
  1379. # to ATen ops. (Dynamo doesn't do decomposition right now).
  1380. if self.name in self.tx.output.unspec_variable_map:
  1381. return self.tx.output.unspec_variable_map[self.name]
  1382. # NB: we specialize on nan input, because our guard modeling in
  1383. # ShapeEnv cannot deal with nan
  1384. if (
  1385. torch._dynamo.config.specialize_float
  1386. or is_constant_source(self.get_source())
  1387. or math.isnan(value)
  1388. ):
  1389. self.install_guards(GuardBuilder.CONSTANT_MATCH)
  1390. return ConstantVariable.create(value=value, source=self.source)
  1391. # NB: At the point we've gotten here, we don't assume static by
  1392. # default. Since we have a guard mechanism, there isn't really any
  1393. # downside to trying to be dynamic for float all the time. Unlike
  1394. # ints, this won't make codegen perf worse. Modest cost to compile
  1395. # time.
  1396. wrapped_value = torch.tensor(value, dtype=torch.float64)
  1397. # TODO: Switch RandomValueSource over to use this, this is more
  1398. # accurate
  1399. assert not isinstance(self.get_source(), RandomValueSource)
  1400. install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
  1401. # The FloatTensorSource here is just for pedantic correctness: if you
  1402. # guard against an UnspecializedPythonVariable, you need to guard
  1403. # against the tensor-ified version of the local, otherwise it's not a
  1404. # Tensor. However, we never let the UnspecializedPythonVariable escape
  1405. # here, so there should never actually be any guards against this
  1406. # source.
  1407. options = {"source": FloatTensorSource(self.get_source()), "raw_value": value}
  1408. # TODO: Maybe the tensor-ification should be built into the source,
  1409. # rather than by special pattern match
  1410. proxy = self.tx.output.root_tracer.create_graph_input(
  1411. re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
  1412. type(wrapped_value),
  1413. source=self.get_source(),
  1414. )
  1415. unspec_var = wrap_fx_proxy_cls(
  1416. UnspecializedPythonVariable,
  1417. tx=self.tx,
  1418. proxy=proxy,
  1419. example_value=wrapped_value,
  1420. **options,
  1421. )
  1422. assert isinstance(unspec_var, UnspecializedPythonVariable)
  1423. self.tx.output.unspec_variable_map[self.name] = unspec_var
  1424. if self.tx.export and not isinstance(self.get_source(), LocalSource):
  1425. raise AssertionError(
  1426. f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
  1427. )
  1428. fake_tensor_value = None
  1429. example_value = unspec_var.proxy.node.meta["example_value"]
  1430. assert is_fake(example_value)
  1431. fake_tensor_value = example_value
  1432. assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
  1433. f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode"
  1434. "({self.tx.fake_mode}) from InstructionTranslator"
  1435. )
  1436. # There's something a bit incoherent about pass_arg_as_tensor,
  1437. # specifically regarding sources.
  1438. #
  1439. # Specifically, suppose we have "x: float" local argument. We
  1440. # eventually end up with an UnspecializedPythonVariable denoting
  1441. # torch.as_tensor(x)... but it's source is still L['x'] (which if you
  1442. # accessed it directly is a float!) So you gotta be careful when
  1443. # setting up your guards, because it's still going to be a float at
  1444. # this point, the conversion happens only precisely at the point we're
  1445. # actually calling the FX graph. This happens to be what we want for
  1446. # shape guard generation, but it's kind of unintuitive.
  1447. proxy.node.meta["grapharg"] = GraphArg(
  1448. self.get_source(),
  1449. wrapped_value,
  1450. pass_arg_as_tensor=True,
  1451. fake_tensor=fake_tensor_value,
  1452. is_tensor=False,
  1453. example_strong_ref=wrapped_value,
  1454. )
  1455. # Directly do item to bypass capture_scalar_outputs
  1456. r = wrap_fx_proxy(
  1457. self.tx,
  1458. self.tx.output.create_proxy(
  1459. "call_method",
  1460. "item",
  1461. *proxy_args_kwargs([unspec_var], {}),
  1462. ),
  1463. )
  1464. self.tx.output.tracked_fakes.append(TrackedFake(r.sym_num, self.source, None))
  1465. return r
  1466. def wrap_unspecialized_primitive(self, value):
  1467. if self.name in self.tx.output.unspec_variable_map:
  1468. return self.tx.output.unspec_variable_map[self.name]
  1469. wrapped_value = torch.tensor(value)
  1470. if not isinstance(self.get_source(), RandomValueSource):
  1471. install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
  1472. options = {"source": self.get_source()}
  1473. options.update({"raw_value": value})
  1474. proxy = self.tx.output.root_tracer.create_graph_input(
  1475. re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
  1476. type(wrapped_value),
  1477. source=self.get_source(),
  1478. )
  1479. unspec_var = wrap_fx_proxy_cls(
  1480. UnspecializedPythonVariable,
  1481. tx=self.tx,
  1482. proxy=proxy,
  1483. example_value=wrapped_value,
  1484. **options,
  1485. )
  1486. self.tx.output.unspec_variable_map[self.name] = unspec_var
  1487. if not is_constant_source(self.get_source()):
  1488. if self.tx.export and not isinstance(self.get_source(), LocalSource):
  1489. raise AssertionError(
  1490. f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
  1491. )
  1492. fake_tensor_value = None
  1493. if isinstance(unspec_var, ConstantVariable):
  1494. # TODO: when can this happen?
  1495. example_value = unspec_var.value
  1496. else:
  1497. example_value = unspec_var.proxy.node.meta["example_value"]
  1498. assert is_fake(example_value)
  1499. fake_tensor_value = example_value
  1500. assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
  1501. f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode"
  1502. "({self.tx.fake_mode}) from InstructionTranslator"
  1503. )
  1504. proxy.node.meta["grapharg"] = GraphArg(
  1505. self.get_source(),
  1506. wrapped_value,
  1507. pass_arg_as_tensor=True,
  1508. fake_tensor=fake_tensor_value,
  1509. is_tensor=False,
  1510. example_strong_ref=wrapped_value,
  1511. )
  1512. return unspec_var
  1513. def _dataclasses_fields_lambda(obj):
  1514. if isinstance(obj, UserDefinedObjectVariable):
  1515. value = obj.value
  1516. elif isinstance(obj, DataClassVariable):
  1517. value = obj.user_cls
  1518. else:
  1519. unimplemented(f"Dataclass fields handling fails for type {obj}")
  1520. items = []
  1521. for field in dataclasses.fields(value):
  1522. source = None
  1523. if obj.source:
  1524. source = GetItemSource(
  1525. AttrSource(obj.source, "__dataclass_fields__"), field.name
  1526. )
  1527. items.append(UserDefinedObjectVariable(field, source=source))
  1528. return TupleVariable(items)
  1529. def wrap_fx_proxy(
  1530. tx, proxy, example_value=None, subclass_type=None, **options
  1531. ) -> VariableTracker:
  1532. kwargs = {
  1533. "tx": tx,
  1534. "proxy": proxy,
  1535. "example_value": example_value,
  1536. "subclass_type": subclass_type,
  1537. **options,
  1538. }
  1539. if subclass_type is None:
  1540. return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  1541. else:
  1542. result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)
  1543. result.install_global(tx)
  1544. return result
  1545. # Note: Unfortunate split due to some gross classes existing that subclass TensorVariable
  1546. # Should be compositional instead
  1547. #
# This is a horribly complicated function that does too many things. To
# explain what it does, let's first talk about the classic usage of
# wrap_fx_proxy for a TensorVariable. There are two primary modes of use:
  1551. #
  1552. # 1. Wrapping a pre-existing Tensor. In this case, example_value is set
  1553. # to the pre-existing Tensor. (Note that this example_value will NOT
  1554. # be the final example_value we put into node.meta['example_value'],
  1555. # instead it is converted into a fake tensor using
  1556. # wrap_to_fake_tensor_and_record and registered as a graph input.)
  1557. #
  1558. # 2. "Wrapping" the result of some Tensor operation Dynamo traced over. In
  1559. # this case, example_value is None (and we are going to figure it out
  1560. # ourselves using FakeTensors, via get_fake_value, which will run
  1561. # the operation represented by the (singular!) FX node referenced by
  1562. # the passed in proxy.)
  1563. #
  1564. # The expectation is you end up with a Tensor output, and everything is
  1565. # straightforwardly traced into the graph.
  1566. #
  1567. # In all cases, the returned `TensorVariable` subclass will have an `example_value`
  1568. # and that `example_value` must be a `FakeTensor` produced by the currently running
  1569. # instance of Dynamo.
  1570. #
  1571. # Upon closer inspection, you may notice that there are a slurry of non-Tensor
  1572. # output cases. What gives? Well, we sometimes trace operations into the
  1573. # graph that don't involve tensors.
  1574. #
  1575. # * Some operators return tuples; we need to recursively handle their
  1576. # contents
  1577. #
  1578. # * Some operators have side effects that will affect subsequent AOTAutograd
  1579. # tracing but don't otherwise return anything.
  1580. #
  1581. # * Some operators return symbolic ints/floats/bools which can go in the
  1582. # graph and be traced (but only if they're actually symbolic! If they're
  1583. # static you don't want to put them in the graph, which means you
  1584. # shouldn't call this function.)
  1585. #
  1586. # The common theme is that you only use this function WHEN YOU ARE TRACING
  1587. # SOMETHING INTO THE GRAPH. This is sort of obvious, because you can't call
  1588. # this function without a proxy.
  1589. def wrap_fx_proxy_cls(
  1590. target_cls, tx, proxy, example_value=None, subclass_type=None, **options
  1591. ):
  1592. from ..symbolic_convert import InstructionTranslatorBase
  1593. assert isinstance(tx, InstructionTranslatorBase)
  1594. if "guards" in options and options["guards"] is not None:
  1595. tx.output.guards.update(options["guards"])
  1596. assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}"
  1597. initial_example_value = example_value
  1598. def _clone_input(value):
  1599. if isinstance(value, torch.Tensor):
  1600. # tensor subclasses will not be converted to FakeTensors and need to be cloned
  1601. if not (
  1602. isinstance(value, FakeTensor)
  1603. or (
  1604. # Is functional tensor fakeified by this instance of Dynamo
  1605. torch._is_functional_tensor(value)
  1606. and maybe_get_fake_mode(value) is tx.fake_mode
  1607. )
  1608. or value.is_nested
  1609. ):
  1610. # NB: ensure strides are preserved
  1611. value = clone_input(value)
  1612. return value
  1613. # with preserve_rng_state():
  1614. if example_value is None:
  1615. # only allow_non_graph_fake in this instance because we handle the non-fake
  1616. # cases properly below.
  1617. example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  1618. # Handle recursive calls here
  1619. elif maybe_get_fake_mode(example_value) is tx.fake_mode:
  1620. pass
  1621. elif isinstance(example_value, torch.Tensor):
  1622. if tx.export:
  1623. # The legacy behavior for real value cache with subclasses was
  1624. # to perform a clone WITHOUT preserving the subclass. It's
  1625. # not entirely clear this is what you actually want though.
  1626. with torch._C.DisableTorchFunctionSubclass():
  1627. proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value)
  1628. # NB: If we're ignoring subclass, then the expectation is you will
  1629. # take the returned TensorVariable and wrap it into a more
  1630. # accurate TensorVariable that is able to track subclass-ness;
  1631. # otherwise this is wrong!
  1632. kwargs = {
  1633. "is_tensor": target_cls in (TensorVariable, TensorWithTFOverrideVariable),
  1634. }
  1635. assert "source" in options and options["source"] is not None
  1636. kwargs["source"] = options["source"]
  1637. example_value = wrap_to_fake_tensor_and_record(example_value, tx=tx, **kwargs)
  1638. if isinstance(example_value, torch.Tensor) and (
  1639. maybe_get_fake_mode(example_value) is not tx.fake_mode
  1640. ):
  1641. raise InternalTorchDynamoError(
  1642. "`example_value` needs to be a `FakeTensor`"
  1643. f"wrapped by this instance of Dynamo. Found: {example_value}"
  1644. )
  1645. if isinstance(example_value, torch.Tensor):
  1646. is_parameter = isinstance(example_value, torch.nn.Parameter)
  1647. # NB: In most (all?) cases, this does not actually do a clone.
  1648. # (WARNING: this means that if we mutate metadata on the fake
  1649. # tensor, the stored example value will update too!)
  1650. example_value = _clone_input(example_value)
  1651. set_example_value(proxy.node, example_value)
  1652. specialized_props = target_cls.specialize(example_value)
  1653. # TODO: not sure about this fake mode test
  1654. if (
  1655. isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
  1656. and example_value.fake_mode is tx.fake_mode
  1657. ):
  1658. tensor_type = subclass_type if subclass_type else torch.Tensor
  1659. specialized_props["class_type"] = (
  1660. torch.nn.Parameter if is_parameter else tensor_type
  1661. )
  1662. options.update(specialized_props)
  1663. return target_cls(proxy, **options)
  1664. elif (
  1665. hasattr(proxy.node.target, "__name__")
  1666. and proxy.node.target.__name__ == "set_state"
  1667. and isinstance(proxy.node.target.__self__, torch._C.Generator)
  1668. or proxy.node.target == torch.random.set_rng_state
  1669. ):
  1670. return TorchInGraphFunctionVariable(proxy.node.target)
  1671. elif (
  1672. proxy.node.target == torch._C._DisableFuncTorch
  1673. or proxy.node.target == torch.cuda._is_in_bad_fork
  1674. ):
  1675. return UserDefinedObjectVariable(example_value)
  1676. elif istype(example_value, torch.Size) and all(
  1677. isinstance(x, int) for x in example_value
  1678. ):
  1679. sizes = [ConstantVariable.create(x) for x in example_value]
  1680. return SizeVariable(sizes, **options)
  1681. elif isinstance(example_value, (tuple, list)):
  1682. set_example_value(proxy.node, example_value)
  1683. unpacked = []
  1684. for i, val in enumerate(example_value):
  1685. if val is None:
  1686. # nn.MultiheadAttention() can return None, see issue #175
  1687. unpacked.append(
  1688. ConstantVariable.create(None, **options),
  1689. )
  1690. else:
  1691. proxy_i = proxy.tracer.create_proxy(
  1692. kind="call_function",
  1693. target=operator.getitem,
  1694. args=(proxy, i),
  1695. kwargs={},
  1696. )
  1697. if "source" in options:
  1698. source = options["source"]
  1699. options_i = options.copy()
  1700. options_i["source"] = GetItemSource(
  1701. base=source, index=i, index_is_slice=False
  1702. )
  1703. else:
  1704. # use the same options object as parent
  1705. options_i = options
  1706. # WARNING: this assumes the same target_cls as this tuple/list call
  1707. unpacked.append(
  1708. wrap_fx_proxy_cls(
  1709. target_cls=target_cls,
  1710. tx=tx,
  1711. proxy=proxy_i,
  1712. example_value=val,
  1713. **options_i,
  1714. )
  1715. )
  1716. if isinstance(example_value, torch.Size):
  1717. # NB: Keep the old proxy around. See SizeVariable for an
  1718. # explanation why
  1719. return SizeVariable(unpacked, proxy, **options)
  1720. elif istype(example_value, tuple):
  1721. return TupleVariable(unpacked, **options)
  1722. elif istype(example_value, (list, immutable_list)):
  1723. return ListVariable(unpacked, mutable_local=MutableLocal(), **options)
  1724. else:
  1725. assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
  1726. example_value, "_fields"
  1727. ), f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}"
  1728. return NamedTupleVariable(unpacked, example_value.__class__, **options)
  1729. elif example_value is None or proxy.node.target is torch.manual_seed:
  1730. return ConstantVariable.create(None, **options)
  1731. elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
  1732. set_example_value(proxy.node, example_value)
  1733. return SymNodeVariable(proxy, example_value, **options)
  1734. elif (
  1735. inspect.isclass(proxy.node.target)
  1736. and issubclass(proxy.node.target, _StreamBase)
  1737. ) or proxy.node.target in [
  1738. device_interface.current_stream
  1739. for _, device_interface in get_registered_device_interfaces()
  1740. ]:
  1741. set_example_value(proxy.node, example_value)
  1742. return StreamVariable(proxy, example_value, example_value.device, **options)
  1743. elif (
  1744. inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase)
  1745. ) or proxy.node.target in [
  1746. device_interface.Event
  1747. for _, device_interface in get_registered_device_interfaces()
  1748. ]:
  1749. set_example_value(proxy.node, example_value)
  1750. return EventVariable(proxy, example_value, **options)
  1751. elif proxy.node.target == "query" and proxy.node.op == "call_method":
  1752. set_example_value(proxy.node, example_value)
  1753. return ConstantVariable(example_value, **options)
  1754. elif (
  1755. example_value is not None
  1756. and isinstance(example_value, _EventBase)
  1757. and proxy.node.target == "record_event"
  1758. and proxy.node.op == "call_method"
  1759. ):
  1760. set_example_value(proxy.node, example_value)
  1761. return EventVariable(proxy, example_value, **options)
  1762. elif isinstance(example_value, int) and proxy.node.target in [
  1763. torch.sym_int,
  1764. getattr,
  1765. operator.getitem,
  1766. torch._utils._element_size,
  1767. torch.seed,
  1768. operator.mod,
  1769. torch._functorch.vmap._validate_and_get_batch_size,
  1770. # some mac builds are missing torch.distributed.get_rank()
  1771. getattr(torch.distributed, "get_rank", _missing),
  1772. getattr(torch.distributed, "get_world_size", _missing),
  1773. # This always wants to be in the graph, even if the constraint
  1774. # results in a constant int
  1775. torch._constrain_as_size,
  1776. ]:
  1777. set_example_value(proxy.node, example_value)
  1778. return ConstantVariable.create(example_value, **options)
  1779. elif isinstance(example_value, torch.backends.cuda.SDPAParams):
  1780. from .sdpa import SDPAParamsVariable
  1781. set_example_value(proxy.node, example_value)
  1782. return SDPAParamsVariable(proxy, **options)
  1783. elif isinstance(example_value, bool) and proxy.node.target in [
  1784. torch.backends.cuda.can_use_flash_attention,
  1785. torch.backends.cuda.can_use_efficient_attention,
  1786. ]:
  1787. set_example_value(proxy.node, example_value)
  1788. return ConstantVariable.create(example_value, **options)
  1789. elif (
  1790. isinstance(example_value, (int, float, bool))
  1791. and proxy.node.target is call_torchbind
  1792. ):
  1793. set_example_value(proxy.node, example_value)
  1794. return ConstantVariable.create(example_value, **options)
  1795. else:
  1796. unimplemented(
  1797. "torch.* op returned non-Tensor "
  1798. + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}"
  1799. )
  1800. # Tracks the sources of all fake tensors we wrap in Dynamo.
  1801. # Used by shape guard computation.
  1802. @dataclasses.dataclass
  1803. class TrackedFake:
  1804. fake: Union[FakeTensor, SymInt]
  1805. source: Source
  1806. # Is None when fake is SymInt
  1807. symbolic_context: Optional[SymbolicContext]
  1808. def __hash__(self) -> int:
  1809. return hash((self.fake, self.source.name()))
  1810. def __eq__(self, other: object) -> bool:
  1811. if isinstance(other, TrackedFake):
  1812. return self.fake is other.fake and self.source.name() == other.source.name()
  1813. return False
# Performs automatic dynamic dim determination.
# Returns a SymbolicContext
def _automatic_dynamic(
    e, tx, source, static_shapes, outer_only=False
) -> SymbolicContext:
    """Decide, per dimension of tensor ``e``, whether it is DYNAMIC, DUCK or STATIC.

    Args:
        e: the tensor being fakeified.
        tx: the instruction translator; ``tx.output.frame_state`` is read and,
            on the automatic-dynamic path, updated under ``source.name()``.
        source: source of ``e``; ``source.name()`` keys ``frame_state``.
        static_shapes: when True (and no earlier branch returns), every dim
            is reported as ``DimDynamic.STATIC``.
        outer_only: for traceable wrapper subclasses, compute only the outer
            tensor's context (used by the recursive call below).

    Returns:
        A ``SubclassSymbolicContext`` for traceable wrapper subclasses (unless
        ``outer_only``), otherwise a ``StatefulSymbolicContext``.

    Strided (non-jagged) nested tensors bail out via ``unimplemented``. Export
    constraints with debug names are recorded in
    ``tx.output.shape_env.source_name_to_debug_name``.
    """
    # strided NT not supported
    if e.is_nested and not isinstance(
        e, torch.nested._internal.nested_tensor.NestedTensor
    ):
        unimplemented("torch.compile does not support strided NestedTensor")

    name = source.name()
    # Reuse the symbol cache of any context previously recorded for this tensor.
    prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
    shape_env_to_source_to_symbol_cache = (
        prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None
    )

    # Get base context if the tensor is a view
    view_base_context: Optional[SymbolicContext] = None
    if e._is_view():
        base_source = AttrSource(source, "_base")
        view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes)

    if is_traceable_wrapper_subclass(e) and not outer_only:
        # Get symbolic context for outer tensor
        outer_context = _automatic_dynamic(
            e, tx, source, static_shapes, outer_only=True
        )

        # Get symbolic contexts for inner tensors
        attrs, _ = type(e).__tensor_flatten__(e)
        inner_contexts = {}  # mapping from attr -> symbolic context
        for attr in attrs:
            inner_tensor = getattr(e, attr)
            inner_source = AttrSource(source, attr)
            inner_context = _automatic_dynamic(
                inner_tensor, tx, inner_source, static_shapes
            )
            inner_contexts[attr] = inner_context

        return SubclassSymbolicContext(
            dynamic_sizes=outer_context.dynamic_sizes,
            constraint_sizes=outer_context.constraint_sizes,
            view_base_context=view_base_context,
            tensor_source=outer_context.tensor_source,
            shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache,
            inner_contexts=inner_contexts,
        )

    if static_shapes:
        return StatefulSymbolicContext(
            dynamic_sizes=[DimDynamic.STATIC] * e.dim(),
            constraint_sizes=[None] * e.dim(),
            view_base_context=view_base_context,
            tensor_source=source,
            shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
        )

    # We preserve the dynamism of inputs. For example, when users call
    # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
    from torch.fx.experimental.symbolic_shapes import is_nested_int

    if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()):
        return StatefulSymbolicContext(
            dynamic_sizes=[
                DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
                for s in e.size()
            ],
            constraint_sizes=[None] * e.dim(),
            view_base_context=view_base_context,
            tensor_source=source,
            shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
        )

    # Prep for automatic dynamic
    frame_state_entry = None
    if name not in tx.output.frame_state:
        # If there is no entry for this source, add the tensor to frame state with its current static size.
        # E.g., {} -> {"x": [2, 4]}
        frame_state_entry = FrameStateSizeEntry(None, None)
        frame_state_entry.size = list(e.size())
    else:
        frame_state_entry = tx.output.frame_state[name]
        if frame_state_entry.size is not None:
            if e.ndim != len(frame_state_entry.size):
                # If there is already an entry, and the dim mismatches, replace the frame state entry with None.
                # E.g. {"x": [2, 3, 4]} -> {"x": None}
                log.debug(
                    "automatic dynamic %s dim %s != %s",
                    name,
                    e.ndim,
                    frame_state_entry.size,
                )
                frame_state_entry.size = None
            else:
                # If there is already an entry, and the dim matches, for every size in the frame state which
                # disagrees with the current static size, replace it with None. E.g., {"x": [2, 3]} -> {"x": [2, None]}
                for i, dim in enumerate(frame_state_entry.size):
                    if dim is not None and e.size()[i] != dim:
                        log.debug(
                            "automatic dynamic %s size(%s) %s != %s",
                            name,
                            i,
                            e.size(i),
                            dim,
                        )
                        frame_state_entry.size[i] = None

    # TODO: index export_constraints ahead of time so we don't have to
    # do a linear scan every time here
    t_id = id(e)
    # Maps dim index -> (constraint range, debug name) for this tensor.
    dim2constraint = {}

    def update_dim2constraint(dim, constraint_range, debug_name):
        # Intersect with any constraint already recorded for this dim.
        if dim in dim2constraint:
            from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

            old_constraint_range, old_debug_name = dim2constraint[dim]
            new_constraint_range = StrictMinMaxConstraint(
                vr=constraint_range.vr & old_constraint_range.vr,
                warn_only=False,
            )
            # It is possible for (non-None) old_debug_name and debug_name to be different
            # but this will only happen when the corresponding Dims can be derived equal.
            new_debug_name = old_debug_name or debug_name
            dim2constraint[dim] = new_constraint_range, new_debug_name
        else:
            dim2constraint[dim] = constraint_range, debug_name

    if tx.output.export_constraints:
        for constraint in tx.output.export_constraints:
            if constraint.t_id == t_id:
                update_dim2constraint(
                    constraint.dim, constraint.constraint_range, constraint.debug_name
                )
            if constraint.shared is not None and constraint.shared.t_id == t_id:
                # We process constraint ranges for each shared dimension separately
                # so that we can directly check range constraint violations on them
                # without looking up which other shared dimensions have this info.
                # In other words, for this t_id, we will have processed all of its
                # constraint ranges, no matter where / how they were specified, by
                # the end of this loop.
                update_dim2constraint(
                    constraint.shared.dim,
                    constraint.constraint_range,
                    constraint.debug_name,
                )

    dynamic_dims = []
    constraint_dims = []
    for i in range(e.dim()):
        # NB: mark dynamic has precedence over static
        marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set())
        marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
        marked_static = i in getattr(e, "_dynamo_static_indices", set())

        # NB: both explicit static and explicit dynamic markings have precedence
        # over automatic dynamic (see how automatic_dynamic is used below).
        automatic_dynamic = config.automatic_dynamic_shapes and (
            frame_state_entry.size is None or frame_state_entry.size[i] is None
        )

        # Reflect the user directive in the frame_state
        # For dynamic, apply None always
        if frame_state_entry.size and marked_dynamic:
            log.debug("automatic dynamic %s marked dynamic", name)
            frame_state_entry.size[i] = None

        # We will process constraints first, as they will imply that we
        # have a dynamic dimension
        # Precedence: export constraints > eager constraints
        constraint = dim2constraint.get(i)
        if constraint is None:
            if marked_dynamic and not config.allow_ignore_mark_dynamic:
                if hasattr(e, "_dynamo_dynamic_range"):
                    # Take the recorded range for this dim (min/max may be None).
                    dim_range = [
                        dr for dr in e._dynamo_dynamic_range if dr.dim == i
                    ].pop()
                    if dim_range.min is None and dim_range.max is None:
                        constraint_dim = RelaxedUnspecConstraint(warn_only=False)
                    else:
                        from torch.fx.experimental.symbolic_shapes import (
                            StrictMinMaxConstraint,
                        )

                        constraint_dim = StrictMinMaxConstraint(
                            vr=ValueRanges(lower=dim_range.min, upper=dim_range.max),
                            warn_only=False,
                        )
                else:
                    constraint_dim = RelaxedUnspecConstraint(warn_only=False)

            elif not marked_static and automatic_dynamic:
                constraint_dim = RelaxedUnspecConstraint(warn_only=True)
            else:
                constraint_dim = None
        else:
            constraint_dim, debug_name = constraint
            if debug_name is not None:
                dim_name = f"{name}.size()[{i}]"
                tx.output.shape_env.source_name_to_debug_name[dim_name] = debug_name
        constraint_dims.append(constraint_dim)

        # Now, figure out if the dim is dynamic/duck/static
        if (
            constraint_dim is not None
            or marked_dynamic
            or marked_weak_dynamic
            or is_nested_int(e.shape[i])
        ):
            # NB: We could assert static_shapes is False here, but it
            # seems better to allow the user to override symbolic_context in this
            # case
            dynamic = DimDynamic.DYNAMIC
        elif static_shapes or config.assume_static_by_default or marked_static:
            dynamic = DimDynamic.STATIC
        else:
            dynamic = DimDynamic.DUCK

        dynamic_dims.append(dynamic)

    # Persist the (possibly relaxed) sizes for the next compilation of this frame.
    tx.output.frame_state[name] = frame_state_entry

    return StatefulSymbolicContext(
        dynamic_sizes=dynamic_dims,
        constraint_sizes=constraint_dims,
        view_base_context=view_base_context,
        tensor_source=source,
        shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
    )
  2020. # See note [Tensor Fakification and Symbol Caching]
  2021. def wrap_to_fake_tensor_and_record(
  2022. e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None
  2023. ):
  2024. if (
  2025. type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor)
  2026. or isinstance(e, torch.Tensor)
  2027. or is_traceable_wrapper_subclass(e)
  2028. ):
  2029. assert source is not None
  2030. static_shapes, reason = tensor_always_has_static_shape(
  2031. e, is_tensor, guard_source=source.guard_source()
  2032. )
  2033. if not parent_context:
  2034. symbolic_context = _automatic_dynamic(e, tx, source, static_shapes)
  2035. else:
  2036. # Parent contexts are passed in when we are recursively creating
  2037. # fake tensors for subclasses. A better design would be not to create a
  2038. # parent/child relationship, but to recursively call _automatic_dynamic
  2039. # as we recursively call wrap_to_fake_tensor_and_record. This runs
  2040. # into bugs around how meta_utils knows and works to create fake tensors
  2041. # with tensor subclasses. Ideally, dynamo would drive both the recursive
  2042. # wrap_to_fake_tensor_and_record and _automatic_dynamic policy creation.
  2043. assert isinstance(source, AttrSource)
  2044. inner_context_name = source.member
  2045. symbolic_context = parent_context.inner_contexts[inner_context_name]
  2046. log.debug(
  2047. "wrap_to_fake %s %s %s %s",
  2048. source.name(),
  2049. tuple(e.shape),
  2050. symbolic_context,
  2051. type(e),
  2052. )
  2053. fake_e = wrap_fake_exception(
  2054. lambda: tx.fake_mode.from_tensor(
  2055. e,
  2056. source=source,
  2057. symbolic_context=symbolic_context,
  2058. )
  2059. )
  2060. if (
  2061. source is not None
  2062. and isinstance(fake_e, FakeTensor)
  2063. and (sym_val := fake_e.item_memo) is not None
  2064. ):
  2065. tx.output.tracked_fakes.append(
  2066. TrackedFake(sym_val, CallMethodItemSource(source), symbolic_context)
  2067. )
  2068. if is_traceable_wrapper_subclass(fake_e):
  2069. attrs, _ = fake_e.__tensor_flatten__()
  2070. for attr in attrs:
  2071. fake_inner = getattr(fake_e, attr)
  2072. inner = getattr(e, attr)
  2073. inner_source = AttrSource(source, attr)
  2074. wrap_to_fake_tensor_and_record(
  2075. inner,
  2076. tx,
  2077. source=inner_source,
  2078. is_tensor=isinstance(fake_inner, torch.Tensor),
  2079. parent_context=symbolic_context,
  2080. )
  2081. tx.output.tracing_context.tensor_to_context[e] = symbolic_context
  2082. if is_sparse_any(fake_e):
  2083. # TODO: for TensorGuards, this eventually may need more
  2084. # fields for the size/stride of any other constituents
  2085. values = fake_e._values() if fake_e.is_sparse else fake_e.values()
  2086. tx.output.input_source_to_sizes_strides[source] = {
  2087. "size": fake_e.size(),
  2088. # TODO: revise this, but for now this stride instead of ()
  2089. # avoids SegFault with PYTORCH_TEST_WITH_DYNAMO=1
  2090. "stride": (1,) * fake_e.ndim,
  2091. "values_size": values.size(),
  2092. "values_stride": values.stride(),
  2093. }
  2094. else:
  2095. tx.output.input_source_to_sizes_strides[source] = {
  2096. "size": fake_e.size(),
  2097. "stride": fake_e.stride(),
  2098. }
  2099. if (
  2100. is_tensor
  2101. and not (static_shapes and source.is_nn_module())
  2102. and not is_constant_source(source)
  2103. ):
  2104. tx.output.tracked_fakes.append(
  2105. TrackedFake(fake_e, source, symbolic_context)
  2106. )
  2107. tx.output.tracked_fakes_id_to_source[id(e)].append(source)
  2108. return fake_e
  2109. else:
  2110. return e
  2111. class SourcelessBuilder:
  2112. """
  2113. Like builder, but stateless and does not require a source. Useful for simple type->VT objects, or objects
  2114. that are being created/evaporated during inlining (ex: consider a locally made list of tensors we then iterate over
  2115. .), such a list should not show up as an artifact from inputs, nor in reconstruction, nor in the graph. However,
  2116. there may be reasons to represent it as a ListVariable internally.
  2117. NOTE - Objects produced here are born UNGUARDED due to the nature of sources!
  2118. NOTE - This class is very new! It will have some rough edges, but it was created to stem the bleeding of giant
  2119. if/else type->VariableTracker trees that were cropping up all over dynamo.
  2120. """
  2121. def __init__(self):
  2122. raise AssertionError("Use SourcelessBuilder.create()")
  2123. @staticmethod
  2124. def create(tx, value) -> VariableTracker:
  2125. value_type = type(value)
  2126. fast_handler = SourcelessBuilder._type_handlers.get(value_type)
  2127. if fast_handler:
  2128. return fast_handler(tx, value)
  2129. if isinstance(value, VariableTracker):
  2130. # This is always valid to call, and useful for recursive calls.
  2131. return value
  2132. elif isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS):
  2133. return UserDefinedObjectVariable(value)
  2134. elif ConstantVariable.is_literal(value):
  2135. return ConstantVariable.create(value)
  2136. elif callable(value) and trace_rules.lookup_callable(value) is not None:
  2137. if is_callable_allowed(value):
  2138. tx.output.has_user_defined_allowed_in_graph = True
  2139. return trace_rules.lookup_callable(value)(value)
  2140. elif is_function_or_wrapper(value):
  2141. return trace_rules.lookup(value)(value)
  2142. elif isinstance(value, enum.Enum):
  2143. return EnumVariable(value)
  2144. elif isinstance(value, (type, abc.ABCMeta)):
  2145. return UserDefinedClassVariable(value)
  2146. elif isinstance(value, types.MethodWrapperType):
  2147. return MethodWrapperVariable(value)
  2148. elif isinstance(value, torch.fx.graph_module.GraphModule):
  2149. return SourcelessGraphModuleVariable(value)
  2150. elif isinstance(
  2151. value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
  2152. ):
  2153. return UserDefinedObjectVariable(value)
  2154. elif PlacementVariable.is_placement(value):
  2155. return PlacementVariable(value)
  2156. elif DeviceMeshVariable.is_device_mesh(value):
  2157. return DeviceMeshVariable(value)
  2158. elif isinstance(value, re.Pattern):
  2159. return RegexPatternVariable(value)
  2160. unimplemented(
  2161. f"Unexpected type in sourceless builder {value_type.__module__}.{value_type.__qualname__}"
  2162. )
  2163. @staticmethod
  2164. def wrap_constant_literal(value):
  2165. assert ConstantVariable.is_literal(value)
  2166. return ConstantVariable.create(value=value)
  2167. @staticmethod
  2168. def make_type_handlers():
  2169. create = SourcelessBuilder.create
  2170. handlers = {}
  2171. for t in common_constant_types:
  2172. handlers[t] = lambda tx, value: ConstantVariable(value)
  2173. handlers[set] = lambda tx, value: SetVariable(
  2174. [create(tx, x) for x in value], mutable_local=MutableLocal()
  2175. )
  2176. handlers[dict] = lambda tx, value: ConstDictVariable(
  2177. {create(tx, k): create(tx, v) for k, v in value.items()},
  2178. type(value),
  2179. mutable_local=MutableLocal(),
  2180. )
  2181. handlers[list] = lambda tx, value: ListVariable(
  2182. [create(tx, x) for x in value], mutable_local=MutableLocal()
  2183. )
  2184. handlers[tuple] = lambda tx, value: TupleVariable(
  2185. [create(tx, x) for x in value]
  2186. )
  2187. handlers[torch.Size] = lambda tx, value: SizeVariable(
  2188. [create(tx, x) for x in value]
  2189. )
  2190. handlers[collections.OrderedDict] = handlers[dict]
  2191. handlers[immutable_dict] = handlers[dict]
  2192. handlers[immutable_list] = handlers[list]
  2193. handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value)
  2194. def passthrough(tx, value):
  2195. return value
  2196. for cls in VariableTrackerMeta.all_subclasses:
  2197. handlers[cls] = passthrough
  2198. return handlers
  2199. SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers()