user_defined.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107
  1. # mypy: ignore-errors
  2. import collections
  3. import contextlib
  4. import enum
  5. import functools
  6. import importlib
  7. import inspect
  8. import itertools
  9. import random
  10. import re
  11. import sys
  12. import threading
  13. import types
  14. import warnings
  15. from typing import Dict, Generic, List
  16. from ..bytecode_transformation import create_call_function
  17. try:
  18. import numpy as np
  19. except ModuleNotFoundError:
  20. np = None
  21. try:
  22. from torch.utils._cxx_pytree import PyTreeSpec
  23. except ImportError:
  24. PyTreeSpec = type(None)
  25. import torch._dynamo.config
  26. import torch.nn
  27. from torch._guards import TracingContext
  28. from .. import variables
  29. from ..exc import unimplemented
  30. from ..guards import GuardBuilder, install_guard
  31. from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource
  32. from ..utils import (
  33. all_hook_names,
  34. build_checkpoint_variable,
  35. check_constant_args,
  36. get_custom_getattr,
  37. has_torch_function,
  38. is_namedtuple_cls,
  39. is_utils_checkpoint,
  40. istype,
  41. namedtuple_fields,
  42. object_has_getattribute,
  43. proxy_args_kwargs,
  44. tensortype_to_dtype,
  45. )
  46. from .base import MutableLocal, VariableTracker
  47. from .ctx_manager import GenericContextWrappingVariable, NullContextVariable
  48. from .dicts import DefaultDictVariable
  49. def is_standard_setattr(val):
  50. return val in (
  51. object.__setattr__,
  52. torch.nn.Module.__setattr__,
  53. )
  54. class UserDefinedVariable(VariableTracker):
  55. pass
  56. class UserDefinedClassVariable(UserDefinedVariable):
  57. def __init__(self, value, **kwargs):
  58. super().__init__(**kwargs)
  59. self.value = value
  60. def as_python_constant(self):
  61. return self.value
  62. def python_type(self):
  63. return type(self.value)
  64. def as_proxy(self):
  65. return self.value
  66. def __str__(self):
  67. return f"UserDefinedClassVariable({self.value})"
  68. @staticmethod
  69. @functools.lru_cache(None)
  70. def _constant_fold_classes():
  71. return {
  72. torch.device,
  73. torch.finfo,
  74. torch.iinfo,
  75. torch.Size,
  76. }
  77. @staticmethod
  78. @functools.lru_cache(None)
  79. def _in_graph_classes():
  80. return set(tensortype_to_dtype.keys()) | {
  81. torch.Tensor,
  82. torch.cuda.Stream,
  83. torch.cuda.Event,
  84. }
  85. def can_constant_fold_through(self):
  86. return self.value in self._constant_fold_classes()
  87. def var_getattr(self, tx, name: str) -> "VariableTracker":
  88. from .. import trace_rules
  89. from . import ConstantVariable, EnumVariable
  90. from .builder import VariableBuilder
  91. if name == "__name__":
  92. return ConstantVariable.create(self.value.__name__)
  93. elif name == "__qualname__":
  94. return ConstantVariable.create(self.value.__qualname__)
  95. source = AttrSource(self.source, name) if self.source is not None else None
  96. try:
  97. obj = inspect.getattr_static(self.value, name)
  98. except AttributeError:
  99. obj = None
  100. if isinstance(obj, staticmethod):
  101. func = obj.__get__(self.value)
  102. if source is not None:
  103. return trace_rules.lookup(func).create_with_source(func, source=source)
  104. else:
  105. return trace_rules.lookup(func)(func)
  106. elif isinstance(obj, classmethod):
  107. return variables.UserMethodVariable(obj.__func__, self, source=source)
  108. elif source:
  109. # __mro__ is a member in < 3.12, an attribute in >= 3.12
  110. if inspect.ismemberdescriptor(obj) or (
  111. sys.version_info >= (3, 12) and name == "__mro__"
  112. ):
  113. return VariableBuilder(tx, source)(obj.__get__(self.value))
  114. # Special handling of collections.OrderedDict.fromkeys()
  115. # Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with
  116. # collections.defaultdict, and both will be handled at UserDefinedClassVariable.call_method().
  117. # Otherwise, it would be wrapped as UserDefinedObjectVariable(collections.OrderedDict.fromkeys),
  118. # and we need duplicate code to handle both cases.
  119. if self.value is collections.OrderedDict and name == "fromkeys":
  120. return super().var_getattr(tx, name)
  121. if ConstantVariable.is_literal(obj):
  122. return ConstantVariable.create(obj)
  123. elif isinstance(obj, enum.Enum):
  124. return EnumVariable(obj)
  125. elif name in getattr(self.value, "__dict__", {}) or (
  126. self.value.__module__.startswith("torch.")
  127. or self.value.__module__ == "torch"
  128. ):
  129. if source:
  130. return VariableBuilder(tx, source)(obj)
  131. return super().var_getattr(tx, name)
  132. def _call_cross_entropy_loss(self, tx, args, kwargs):
  133. """
  134. functional: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean',
  135. label_smoothing=0.0
  136. non functional ctor: weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean',
  137. label_smoothing=0.0
  138. non functional loss call: input, target, optional_output
  139. """
  140. from . import ConstantVariable
  141. def normalize_args(
  142. weight=ConstantVariable.create(None),
  143. size_average=ConstantVariable.create(None),
  144. ignore_index=ConstantVariable.create(-100),
  145. reduce=ConstantVariable.create(None),
  146. reduction=ConstantVariable.create("mean"),
  147. label_smoothing=ConstantVariable.create(0.0),
  148. ):
  149. return (
  150. weight,
  151. size_average,
  152. ignore_index,
  153. reduce,
  154. reduction,
  155. label_smoothing,
  156. )
  157. (
  158. weight,
  159. size_average,
  160. ignore_index,
  161. reduce_arg,
  162. reduction,
  163. label_smoothing,
  164. ) = normalize_args(*args, **kwargs)
  165. def fake_cross_entropy_loss(input, target):
  166. from .builder import wrap_fx_proxy
  167. return wrap_fx_proxy(
  168. tx=tx,
  169. proxy=tx.output.create_proxy(
  170. "call_function",
  171. torch.nn.functional.cross_entropy,
  172. *proxy_args_kwargs(
  173. [
  174. input,
  175. target,
  176. weight,
  177. size_average,
  178. ignore_index,
  179. reduce_arg,
  180. reduction,
  181. label_smoothing,
  182. ],
  183. {},
  184. ),
  185. ),
  186. )
  187. return variables.LambdaVariable(fake_cross_entropy_loss)
  188. def call_method(
  189. self,
  190. tx,
  191. name,
  192. args: "List[VariableTracker]",
  193. kwargs: "Dict[str, VariableTracker]",
  194. ) -> "VariableTracker":
  195. if (
  196. name == "__subclasses__"
  197. and len(args) == 0
  198. and not kwargs
  199. and "__subclasses__" not in self.value.__dict__
  200. ):
  201. options = {"mutable_local": MutableLocal()}
  202. subs_as_vars: List[VariableTracker] = list()
  203. for sub in self.value.__subclasses__():
  204. source = AttrSource(tx.import_source(sub.__module__), sub.__name__)
  205. subs_as_vars.append(
  206. variables.UserDefinedClassVariable(sub, source=source)
  207. )
  208. return variables.ListVariable(subs_as_vars, **options)
  209. elif (
  210. self.value in {collections.OrderedDict, collections.defaultdict}
  211. and name == "fromkeys"
  212. ):
  213. from .builtin import BuiltinVariable
  214. return BuiltinVariable.call_custom_dict_fromkeys(
  215. tx, self.value, *args, **kwargs
  216. )
  217. elif name == "__eq__" and len(args) == 1 and hasattr(args[0], "value"):
  218. return variables.ConstantVariable(self.value == args[0].value)
  219. elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"):
  220. return variables.ConstantVariable(self.value != args[0].value)
  221. return super().call_method(tx, name, args, kwargs)
  222. def call_function(
  223. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  224. ) -> "VariableTracker":
  225. from ..side_effects import SideEffects
  226. from .builder import SourcelessBuilder, wrap_fx_proxy
  227. from .builtin import BuiltinVariable
  228. constant_args = check_constant_args(args, kwargs)
  229. if self.can_constant_fold_through() and constant_args:
  230. # constant fold
  231. return variables.ConstantVariable.create(
  232. self.as_python_constant()(
  233. *[x.as_python_constant() for x in args],
  234. **{k: v.as_python_constant() for k, v in kwargs.items()},
  235. ),
  236. )
  237. elif self.value is torch.nn.CrossEntropyLoss:
  238. return self._call_cross_entropy_loss(tx, args, kwargs)
  239. elif self.value is contextlib.nullcontext:
  240. return NullContextVariable()
  241. elif self.value is collections.OrderedDict:
  242. return BuiltinVariable.call_custom_dict(
  243. tx, collections.OrderedDict, *args, **kwargs
  244. )
  245. elif (
  246. self.value is collections.defaultdict
  247. and len(args) <= 1
  248. and DefaultDictVariable.is_supported_arg(args[0])
  249. ):
  250. return DefaultDictVariable(
  251. {},
  252. collections.defaultdict,
  253. args[0],
  254. mutable_local=MutableLocal(),
  255. )
  256. elif self.value is collections.deque and not kwargs:
  257. if len(args) == 0:
  258. items = []
  259. elif len(args) == 1 and args[0].has_unpack_var_sequence(tx):
  260. items = args[0].unpack_var_sequence(tx)
  261. else:
  262. unimplemented("deque() with more than 1 arg not supported")
  263. return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
  264. elif self.value is functools.partial:
  265. if not args:
  266. unimplemented("functools.partial malformed")
  267. # The first arg, a callable (the ctor below will assert on types)
  268. fn = args[0]
  269. rest_args = args[1:]
  270. # guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the
  271. # args and keywords
  272. return variables.functions.FunctoolsPartialVariable(
  273. fn, args=rest_args, keywords=kwargs
  274. )
  275. elif self.value is warnings.catch_warnings and not args:
  276. return variables.CatchWarningsCtxManagerVariable.create(tx, kwargs)
  277. elif (
  278. issubclass(type(self.value), type)
  279. and hasattr(
  280. self.value, "__enter__"
  281. ) # TODO(voz): These can invoke user code!
  282. and hasattr(
  283. self.value, "__exit__"
  284. ) # TODO(voz): These can invoke user code!
  285. and check_constant_args(args, kwargs)
  286. and self.value.__init__ == object.__init__
  287. and len(kwargs) == 0 # TODO(ybliang): support kwargs
  288. ):
  289. unwrapped_args = [x.as_python_constant() for x in args]
  290. return GenericContextWrappingVariable(
  291. unwrapped_args,
  292. cm_obj=self.value(*unwrapped_args),
  293. )
  294. elif is_namedtuple_cls(self.value):
  295. fields = namedtuple_fields(self.value)
  296. # check if this a quasi-namedtuple or a real one
  297. if self.value.__module__ == "torch.return_types":
  298. # create pseudo-defaults from values of the quasi-namedtuple
  299. field_defaults = dict(zip(fields, args[0].items))
  300. else:
  301. field_defaults = self.value._field_defaults
  302. items = list(args)
  303. items.extend([None] * (len(fields) - len(items)))
  304. var_tracker_kwargs = {}
  305. for field_name, var_tracker in zip(fields, items):
  306. if var_tracker is None:
  307. if field_name in kwargs:
  308. field_var = kwargs[field_name]
  309. else:
  310. assert field_name in field_defaults
  311. field_var = SourcelessBuilder.create(
  312. tx, field_defaults[field_name]
  313. )
  314. var_tracker_kwargs[field_name] = field_var
  315. for name, value in var_tracker_kwargs.items():
  316. assert name in fields
  317. items[fields.index(name)] = value
  318. assert all(x is not None for x in items)
  319. return variables.NamedTupleVariable(items, self.value)
  320. elif (
  321. self.is_standard_new()
  322. and SideEffects.cls_supports_mutation_side_effects(self.value)
  323. and self.source
  324. ):
  325. var = tx.output.side_effects.track_object_new(
  326. self.source,
  327. self.value,
  328. variables.UnspecializedNNModuleVariable
  329. if issubclass(self.value, torch.nn.Module)
  330. else UserDefinedObjectVariable,
  331. {},
  332. )
  333. if (
  334. inspect.getattr_static(self.value, "__init__", None)
  335. is torch.nn.Module.__init__
  336. ):
  337. tx.output.side_effects.store_attr(
  338. var,
  339. "__call_nn_module_init",
  340. variables.ConstantVariable.create(True),
  341. )
  342. return var
  343. else:
  344. var.call_method(tx, "__init__", args, kwargs)
  345. return var
  346. elif variables.CustomizedDictVariable.is_matching_cls(self.value):
  347. options = {"mutable_local": MutableLocal()}
  348. return variables.CustomizedDictVariable.create(
  349. self.value, args, kwargs, options
  350. )
  351. elif variables.DataClassVariable.is_matching_cls(self.value):
  352. options = {"mutable_local": MutableLocal()}
  353. return variables.DataClassVariable.create(self.value, args, kwargs, options)
  354. elif (
  355. variables.RestrictedListSubclassVariable.is_matching_cls(self.value)
  356. and self.source
  357. ):
  358. return variables.RestrictedListSubclassVariable(
  359. variables.BuiltinVariable(list).call_function(tx, args, kwargs).items,
  360. user_cls=self.value,
  361. user_cls_source=self.source,
  362. mutable_local=MutableLocal(),
  363. )
  364. elif self.value in self._in_graph_classes():
  365. # torch.LongTensor cannot accept a list of FakeTensors.
  366. # So we stack the list of FakeTensors instead.
  367. if (
  368. np
  369. and self.value in tensortype_to_dtype
  370. and len(args) == 1
  371. and isinstance(args[0], variables.ListVariable)
  372. and len(args[0].items) > 1
  373. and all(isinstance(x, variables.TensorVariable) for x in args[0].items)
  374. ):
  375. # Stack FakeTensor
  376. stacked = wrap_fx_proxy(
  377. tx=tx,
  378. proxy=tx.output.create_proxy(
  379. "call_function",
  380. torch.stack,
  381. *proxy_args_kwargs(args, kwargs),
  382. ),
  383. )
  384. args = [stacked]
  385. tensor_variable = wrap_fx_proxy(
  386. tx=tx,
  387. proxy=tx.output.create_proxy(
  388. "call_function",
  389. self.value,
  390. *proxy_args_kwargs(args, kwargs),
  391. ),
  392. )
  393. return tensor_variable
  394. elif issubclass(self.value, enum.Enum) and len(args) == 1 and not kwargs:
  395. options = {"mutable_local": MutableLocal()}
  396. return variables.EnumVariable.create(self.value, args[0], options)
  397. return super().call_function(tx, args, kwargs)
  398. def is_standard_new(self):
  399. """Check for __new__ being overridden"""
  400. new_fn = inspect.getattr_static(self.value, "__new__", None)
  401. if isinstance(new_fn, staticmethod):
  402. new_fn = new_fn.__func__
  403. return new_fn in (object.__new__, Generic.__new__)
  404. def call_hasattr(self, tx, name: str) -> "VariableTracker":
  405. if self.source:
  406. source = AttrSource(self.source, name)
  407. install_guard(source.make_guard(GuardBuilder.HASATTR))
  408. return variables.ConstantVariable(hasattr(self.value, name))
  409. return super().call_hasattr(tx, name)
  410. def const_getattr(self, tx, name):
  411. if name == "__name__":
  412. return self.value.__name__
  413. return super().const_getattr(tx, name)
  414. class NO_SUCH_SUBOBJ:
  415. pass
  416. class UserDefinedObjectVariable(UserDefinedVariable):
  417. """
  418. Mostly objects of defined type. Catch-all for something where we only know the type.
  419. """
  420. _nonvar_fields = {"value", "value_type", *UserDefinedVariable._nonvar_fields}
  421. def __init__(self, value, value_type=None, **kwargs):
  422. super().__init__(**kwargs)
  423. self.value = value
  424. self.value_type = value_type or type(value)
  425. assert type(value) is self.value_type
  426. def __str__(self):
  427. inner = self.value_type.__name__
  428. if inner in [
  429. "builtin_function_or_method",
  430. "getset_descriptor",
  431. "method_descriptor",
  432. "method",
  433. ]:
  434. inner = str(getattr(self.value, "__name__", None))
  435. return f"{self.__class__.__name__}({inner})"
  436. def python_type(self):
  437. return self.value_type
  438. def guard_as_python_constant(self):
  439. if self.source:
  440. install_guard(self.source.make_guard(GuardBuilder.ID_MATCH))
  441. return self.value
  442. return super().guard_as_python_constant()
  443. def torch_function_check(self):
  444. assert has_torch_function(
  445. self
  446. ), f"calling torch function on object without __torch_function__ {self}"
  447. def get_torch_fn(self, tx):
  448. self.torch_function_check()
  449. from .torch_function import build_torch_function_fn
  450. return build_torch_function_fn(tx, self.value, self.source)
  451. def call_torch_function(self, tx, fn, types, args, kwargs):
  452. self.torch_function_check()
  453. from .torch_function import _get_subclass_type_var, call_torch_function
  454. return call_torch_function(
  455. tx,
  456. _get_subclass_type_var(tx, self),
  457. self.get_torch_fn(tx),
  458. fn,
  459. types,
  460. args,
  461. kwargs,
  462. )
  463. @staticmethod
  464. @functools.lru_cache(None)
  465. def _supported_random_functions():
  466. fns = {
  467. random.random,
  468. random.randint,
  469. random.randrange,
  470. random.uniform,
  471. }
  472. return fns
  473. def _maybe_get_baseclass_method(self, name):
  474. if name not in getattr(self.value, "__dict__", {}):
  475. try:
  476. return inspect.getattr_static(type(self.value), name)
  477. except AttributeError:
  478. pass
  479. return None
  480. def call_method(
  481. self,
  482. tx,
  483. name,
  484. args: "List[VariableTracker]",
  485. kwargs: "Dict[str, VariableTracker]",
  486. ) -> "VariableTracker":
  487. from . import (
  488. BuiltinVariable,
  489. ConstantVariable,
  490. TupleVariable,
  491. UserMethodVariable,
  492. )
  493. method = self._maybe_get_baseclass_method(name)
  494. if method is not None:
  495. if method is object.__init__:
  496. return ConstantVariable.create(None)
  497. if is_standard_setattr(method):
  498. return self.method_setattr_standard(tx, *args, **kwargs)
  499. # [NOTE] OrderedDict, dict subtypes must always have source
  500. # We cannot instantiate such subtypes in-graph due to builtin __new__
  501. if method is collections.OrderedDict.keys:
  502. # subclass of OrderedDict
  503. assert not (args or kwargs)
  504. assert self.source # OrderedDict, dict subtypes must always have source
  505. keys = list(self.value.keys())
  506. assert all(map(ConstantVariable.is_literal, keys))
  507. install_guard(self.source.make_guard(GuardBuilder.DICT_CONST_KEYS))
  508. tx.output.guard_on_key_order.add(self.source.name())
  509. return TupleVariable([ConstantVariable.create(k) for k in keys])
  510. if (
  511. method in (collections.OrderedDict.__contains__, dict.__contains__)
  512. and len(args) == 1
  513. and isinstance(args[0], (ConstantVariable, BuiltinVariable))
  514. and inspect.getattr_static(type(self.value), "keys")
  515. in (collections.OrderedDict.keys, dict.keys)
  516. ):
  517. assert not kwargs
  518. assert self.source # OrderedDict, dict subtypes must always have source
  519. # TODO(anijain2305) - Why do we need to guard on all keys?
  520. install_guard(self.source.make_guard(GuardBuilder.DICT_CONST_KEYS))
  521. return ConstantVariable.create(
  522. args[0].as_python_constant() in self.value
  523. )
  524. if method is collections.OrderedDict.items and isinstance(
  525. self.value, collections.OrderedDict
  526. ):
  527. assert self.source # OrderedDict, dict subtypes must always have source
  528. assert not (args or kwargs)
  529. items = []
  530. keys = self.call_method(tx, "keys", [], {})
  531. for key in keys.unpack_var_sequence(tx):
  532. items.append(
  533. TupleVariable(
  534. [key, self.odict_getitem(tx, key)],
  535. )
  536. )
  537. tx.output.guard_on_key_order.add(self.source.name())
  538. return TupleVariable(items)
  539. if method is collections.OrderedDict.__getitem__ and len(args) == 1:
  540. assert not kwargs
  541. assert self.source # OrderedDict, dict subtypes must always have source
  542. return self.odict_getitem(tx, args[0])
  543. if (
  544. method in (object.__ne__, object.__eq__)
  545. and len(args) == 1
  546. and not kwargs
  547. and hasattr(args[0], "value")
  548. ):
  549. return ConstantVariable(
  550. (self.value is args[0].value) is (method is object.__eq__)
  551. )
  552. # check for methods implemented in C++
  553. if isinstance(method, types.FunctionType):
  554. source = (
  555. None
  556. if self.source is None
  557. else AttrSource(AttrSource(self.source, "__class__"), name)
  558. )
  559. # TODO(jansel): add a guard to check for monkey patching?
  560. return UserMethodVariable(method, self, source=source).call_function(
  561. tx, args, kwargs
  562. )
  563. if method is list.__len__ and self.source and not (args or kwargs):
  564. install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
  565. return ConstantVariable(len(self.value))
  566. return super().call_method(tx, name, args, kwargs)
  567. def method_setattr_standard(self, tx, name, value):
  568. try:
  569. name = name.as_python_constant()
  570. except NotImplementedError:
  571. unimplemented(f"non-const setattr name: {name}")
  572. if not tx.output.side_effects.is_attribute_mutation(self):
  573. unimplemented(f"setattr({self}, {name}, ...)")
  574. tx.output.side_effects.store_attr(self, name, value)
  575. return variables.ConstantVariable(None)
  576. def needs_slow_setattr(self):
  577. return not is_standard_setattr(
  578. inspect.getattr_static(self.value, "__setattr__", None)
  579. )
  580. def unpack_var_sequence(self, tx):
  581. if (
  582. self.source
  583. and self._maybe_get_baseclass_method("__iter__") is list.__iter__
  584. and self._maybe_get_baseclass_method("__len__") is list.__len__
  585. and self._maybe_get_baseclass_method("__getitem__") is list.__getitem__
  586. ):
  587. install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
  588. return [
  589. variables.LazyVariableTracker.create(
  590. self.value[k],
  591. source=GetItemSource(self.source, k),
  592. )
  593. for k in range(len(self.value))
  594. ]
  595. return super().unpack_var_sequence(tx)
  596. def next_variable(self, tx):
  597. return self.call_method(tx, "__next__", [], {})
  598. def is_supported_random(self):
  599. try:
  600. return self.value in self._supported_random_functions()
  601. except TypeError:
  602. # TypeError: unhashable type
  603. return False
  604. def call_function(
  605. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  606. ) -> "VariableTracker":
  607. from .. import trace_rules
  608. from .builder import VariableBuilder
  609. if (
  610. self.is_supported_random()
  611. and all(k.is_python_constant() for k in args)
  612. and all(v.is_python_constant() for v in kwargs.values())
  613. ):
  614. args = [x.as_python_constant() for x in args]
  615. kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
  616. random_call_index = len(tx.output.random_calls)
  617. example_value = self.value(*args, **kwargs)
  618. source = RandomValueSource(random_call_index)
  619. tx.output.random_calls.append((self.value, args, kwargs))
  620. # TODO: arguably, this should route to wrap_symint/wrap_symfloat
  621. # (currently hypothetical), but I'm not going to poke my hand in
  622. # this nest for now
  623. return VariableBuilder(tx, source).wrap_unspecialized_primitive(
  624. example_value
  625. )
  626. elif istype(self.value, types.MethodType):
  627. func = self.value.__func__
  628. obj = self.value.__self__
  629. if (
  630. func is torch.utils._contextlib._DecoratorContextManager.clone
  631. and variables.TorchCtxManagerClassVariable.is_matching_cls(
  632. obj.__class__
  633. )
  634. and not (args or kwargs)
  635. ):
  636. return variables.TorchCtxManagerClassVariable(
  637. obj.__class__
  638. ).call_function(tx, args, kwargs)
  639. if (
  640. func is torch.autograd.grad_mode.inference_mode.clone
  641. and obj.__class__ is torch.autograd.grad_mode.inference_mode
  642. ):
  643. # simulate the inference_mode.clone implementation
  644. var = variables.ConstantVariable(obj.mode)
  645. return variables.TorchCtxManagerClassVariable(
  646. obj.__class__
  647. ).call_function(tx, [var], kwargs)
  648. if self.source is None:
  649. unimplemented(
  650. "Sourceless UserDefinedObjectVariable method not supported"
  651. )
  652. func_src = AttrSource(self.source, "__func__")
  653. func_var = VariableBuilder(tx, func_src)(func)
  654. obj_src = AttrSource(self.source, "__self__")
  655. obj_var = VariableBuilder(tx, obj_src)(obj)
  656. return func_var.call_function(tx, [obj_var] + args, kwargs)
  657. elif (
  658. istype(self.value, functools.partial)
  659. and trace_rules.lookup(self.value.func)
  660. == variables.TorchInGraphFunctionVariable
  661. and all(
  662. variables.ConstantVariable.is_literal(v)
  663. for v in itertools.chain(self.value.args, self.value.keywords.values())
  664. )
  665. ):
  666. if self.source:
  667. install_guard(
  668. AttrSource(self.source, "func").make_guard(GuardBuilder.ID_MATCH),
  669. AttrSource(self.source, "args").make_guard(
  670. GuardBuilder.CONSTANT_MATCH
  671. ),
  672. AttrSource(self.source, "keywords").make_guard(
  673. GuardBuilder.CONSTANT_MATCH
  674. ),
  675. )
  676. partial_args = [
  677. variables.ConstantVariable.create(v) for v in self.value.args
  678. ]
  679. partial_args.extend(args)
  680. partial_kwargs = {
  681. k: variables.ConstantVariable.create(v)
  682. for k, v in self.value.keywords.items()
  683. }
  684. partial_kwargs.update(kwargs)
  685. if is_utils_checkpoint(self.value.func):
  686. return build_checkpoint_variable().call_function(
  687. tx, partial_args, partial_kwargs
  688. )
  689. return variables.TorchInGraphFunctionVariable(
  690. self.value.func
  691. ).call_function(tx, partial_args, partial_kwargs)
  692. elif callable(self.value):
  693. if self.source:
  694. install_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH))
  695. return self.call_method(tx, "__call__", args, kwargs)
  696. return super().call_function(tx, args, kwargs)
  697. def _check_for_getattribute(self):
  698. if object_has_getattribute(self.value):
  699. unimplemented("UserDefinedObjectVariable with custom __getattribute__")
  700. def _check_for_getattr(self):
  701. return get_custom_getattr(self.value)
  702. def _getattr_static(self, name):
  703. if (
  704. isinstance(self.value, (torch.nn.Module, PyTreeSpec))
  705. or "__slots__" in self.value.__class__.__dict__
  706. or type(self.value) == threading.local
  707. ):
  708. try:
  709. cls_var = inspect.getattr_static(
  710. self.value.__class__, name, NO_SUCH_SUBOBJ
  711. )
  712. if cls_var is not NO_SUCH_SUBOBJ and name not in self.value.__dict__:
  713. # maybe user-defined @property that we need to inline
  714. return cls_var
  715. except AttributeError:
  716. pass # __slots__
  717. # this might call torch.nn.Module.__getattr__
  718. subobj = getattr(self.value, name)
  719. else:
  720. subobj = inspect.getattr_static(self.value, name)
  721. return subobj
  722. def has_key_in_generic_dict(self, tx, key):
  723. self._check_for_getattribute()
  724. if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
  725. mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
  726. return not isinstance(mutated_attr, variables.DeletedVariable)
  727. return key in self.value.__dict__
    def var_getattr(self, tx, name):
        """Trace ``getattr(self.value, name)`` and return a VariableTracker for it.

        Resolution order, as implemented below:
          1. an attribute already mutated in the traced frame (side effects);
          2. ``__dict__`` itself, returned as a ``GetAttrVariable``;
          3. static lookup via ``self._getattr_static``; only if that raises
             ``AttributeError`` is a user ``__getattr__`` consulted;
          4. descriptor-like results (property, lazy_property, staticmethod,
             classmethod, plain functions / bound methods);
          5. data attributes (instance ``__dict__`` entries, literals, tensors,
             modules, compiled regexes, pytree specs);
          6. class attributes of types from ``torch.*`` modules;
          7. a few special cases, then a generic ``GetAttrVariable`` fallback.
        """
        from .. import trace_rules
        from . import ConstantVariable
        from .builder import VariableBuilder

        value = self.value
        source = AttrSource(self.source, name) if self.source else None
        # NOTE(review): return value is unused here; presumably called for its
        # side effect of rejecting a custom __getattribute__ — confirm.
        self._check_for_getattribute()

        # An attribute written earlier in the traced code must be read back
        # from the side-effects tracker, not from the real (unmodified) object.
        if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
            return tx.output.side_effects.load_attr(self, name)

        if name == "__dict__":
            options = {"source": source}
            return variables.GetAttrVariable(self, name, **options)

        try:
            subobj = self._getattr_static(name)
        except AttributeError:
            # Normal lookup failed; like Python, only now fall back to a
            # user-defined __getattr__ (if any).
            subobj = NO_SUCH_SUBOBJ
            getattr_fn = self._check_for_getattr()
            if isinstance(getattr_fn, types.FunctionType):
                # Dynamo is going to trace the __getattr__ function with
                # args=name. Set the source accordingly.
                new_source = None
                if self.source:
                    new_source = AttrSource(self.source, "__getattr__")
                return variables.UserMethodVariable(
                    getattr_fn, self, source=new_source
                ).call_function(tx, [ConstantVariable.create(name)], {})
            elif getattr_fn is not None:
                unimplemented("UserDefined with non-function __getattr__")

        if isinstance(subobj, property):
            if self.source:
                # Read the class attribute to reach the property
                source = AttrSource(AttrSource(self.source, "__class__"), name)
                # Get the getter function
                source = AttrSource(source, "fget")
            # Inline the property getter with the object as its only argument.
            return variables.UserMethodVariable(
                subobj.fget, self, source=source
            ).call_function(tx, [], {})
        elif isinstance(subobj, torch.distributions.utils.lazy_property):
            # Trace lazy_property.__get__(descriptor, instance) directly.
            subobj_var = UserDefinedObjectVariable(subobj, source=source)
            return variables.UserMethodVariable(
                subobj.__get__.__func__, subobj_var, source=source
            ).call_function(tx, [self], {})
        elif isinstance(subobj, staticmethod):
            # Unwrap the staticmethod and let trace_rules pick the tracker type.
            func = subobj.__get__(self.value)
            if source is not None:
                return trace_rules.lookup(func).create_with_source(func, source=source)
            else:
                return trace_rules.lookup(func)(func)
        elif isinstance(subobj, classmethod):
            # Bind the underlying function to the object's class.
            return variables.UserMethodVariable(
                subobj.__func__, self.var_getattr(tx, "__class__"), source=source
            )
        elif isinstance(subobj, types.FunctionType) or (
            isinstance(subobj, types.MethodType)
            and isinstance(self.value, torch.nn.Module)
        ):
            # Since we get subobj via self._getattr_static, which may not trigger dynamic lookup.
            # Static lookup can't tell us it's a method or function correctly,
            # so we trigger dynamic lookup here to get the correct type.
            dynamic_subobj = getattr(self.value, name)

            # Only while static and dynamic lookups returned the very same
            # object: follow `_torchdynamo_inline` redirections, extending the
            # source path to match.
            while dynamic_subobj is subobj and hasattr(subobj, "_torchdynamo_inline"):
                subobj = subobj._torchdynamo_inline
                dynamic_subobj = subobj
                source = AttrSource(source, "_torchdynamo_inline") if source else None

            if isinstance(subobj, types.MethodType):
                if dynamic_subobj.__self__ is not self.value:
                    unimplemented("__self__ mismatch for bound method")
                func = subobj.__func__
            else:
                assert isinstance(subobj, types.FunctionType)
                func = subobj

            # Decide from the *dynamic* result whether this is bound to self.
            if inspect.ismethod(dynamic_subobj):
                return variables.UserMethodVariable(func, self, source=source)
            elif inspect.isfunction(dynamic_subobj):
                if is_utils_checkpoint(func):
                    return build_checkpoint_variable(source=source)
                elif source is not None:
                    return trace_rules.lookup(func).create_with_source(
                        func, source=source
                    )
                else:
                    return trace_rules.lookup(func)(func)
            # Neither a method nor a function after dynamic lookup: fall
            # through to the generic handling below.

        # Data attributes: instance-dict entries, literals, tensors, modules
        # and compiled regexes.
        if (
            name in getattr(value, "__dict__", {})
            or ConstantVariable.is_literal(subobj)
            or isinstance(
                subobj,
                (
                    torch.Tensor,
                    torch.nn.Module,
                    re.Pattern,
                ),
            )
        ):
            if source:
                # Guard that the attribute still exists, then wrap the value.
                install_guard(source.make_guard(GuardBuilder.HASATTR))
                return VariableBuilder(tx, source)(subobj)
            elif ConstantVariable.is_literal(subobj):
                return ConstantVariable.create(subobj)
            elif (
                type(subobj) == torch.utils._pytree.TreeSpec
                or type(subobj) == torch.utils._pytree.LeafSpec
                or type(value) == torch.utils._pytree.TreeSpec
            ):
                from .builder import SourcelessBuilder

                return SourcelessBuilder.create(tx, subobj)

        # Class-level attributes of non-callable objects whose type lives in a
        # "torch." module (excluding torch.optim), or regex class attributes:
        # wrap via VariableBuilder, synthesizing a global
        # `<module>.<TypeName>.<name>` source when the object has none.
        if (
            name not in getattr(value, "__dict__", {})
            and (
                type(value).__module__.startswith("torch.")
                or isinstance(subobj, re.Pattern)
            )
            and "torch.optim" not in type(value).__module__
            and not callable(value)
            and not isinstance(subobj, types.MethodDescriptorType)
        ):
            if not source:
                # The type must be importable under its own module/name for
                # the synthesized source to resolve back to it.
                assert getattr(
                    importlib.import_module(type(value).__module__),
                    type(value).__name__,
                ) is type(value)
                source = AttrSource(
                    AttrSource(
                        tx.import_source(type(value).__module__), type(value).__name__
                    ),
                    name,
                )

            return VariableBuilder(tx, source)(subobj)
        options = {"source": source}
        if isinstance(
            subobj,
            (
                torch.distributions.constraints._Interval,
                torch.distributions.constraints._Real,
                torch.distributions.constraints.Constraint,
            ),
        ):
            return UserDefinedObjectVariable(subobj, **options)
        elif isinstance(self.value, torch.nn.Module) and name in all_hook_names:
            assert isinstance(subobj, collections.OrderedDict)
            # Empty hook dicts become an empty ConstDictVariable; non-empty
            # ones fall through to the generic GetAttrVariable below.
            if not subobj:
                return variables.ConstDictVariable(
                    subobj, collections.OrderedDict, **options
                )

        if name == "__class__":
            return UserDefinedClassVariable(type(self.value), **options)

        # Fallback: represent the access as a GetAttrVariable on self.
        return variables.GetAttrVariable(self, name, **options)
  875. def call_hasattr(self, tx, name: str) -> "VariableTracker":
  876. if tx.output.side_effects.is_attribute_mutation(self):
  877. try:
  878. result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
  879. return variables.ConstantVariable.create(
  880. not isinstance(result, variables.DeletedVariable)
  881. )
  882. except KeyError:
  883. pass
  884. if self.source:
  885. install_guard(
  886. AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
  887. )
  888. if self._check_for_getattribute() or self._check_for_getattr():
  889. unimplemented("hasattr with custom __getattr__")
  890. try:
  891. self._getattr_static(name)
  892. return variables.ConstantVariable.create(True)
  893. except AttributeError:
  894. return variables.ConstantVariable.create(False)
  895. def odict_getitem(self, tx, key):
  896. from .builder import VariableBuilder
  897. from .dicts import is_hashable
  898. # TODO this should probably be merged with the dict handling
  899. index = (
  900. key.source
  901. if is_hashable(key) and key.source is not None
  902. else key.as_python_constant()
  903. )
  904. return VariableBuilder(
  905. tx,
  906. ODictGetItemSource(self.source, index),
  907. )(collections.OrderedDict.__getitem__(self.value, key.as_python_constant()))
  908. class SourcelessGraphModuleVariable(UserDefinedObjectVariable):
  909. def __init__(
  910. self,
  911. value,
  912. **kwargs,
  913. ):
  914. super().__init__(value, **kwargs)
  915. def call_method(
  916. self,
  917. tx,
  918. name,
  919. args: "List[VariableTracker]",
  920. kwargs: "Dict[str, VariableTracker]",
  921. ) -> "VariableTracker":
  922. fn_variable = variables.UserFunctionVariable(self.value.forward.__func__)
  923. args = [self] + args
  924. return tx.inline_user_function_return(
  925. fn_variable,
  926. args,
  927. kwargs,
  928. )
  929. class KeyedJaggedTensorVariable(UserDefinedObjectVariable):
  930. @staticmethod
  931. def is_matching_object(obj):
  932. mod = sys.modules.get("torchrec.sparse.jagged_tensor")
  933. return mod is not None and type(obj) is mod.KeyedJaggedTensor
  934. def __init__(self, value, **kwargs):
  935. from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
  936. assert type(value) is KeyedJaggedTensor
  937. super().__init__(value, **kwargs)
  938. def var_getattr(self, tx, name):
  939. if (
  940. torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt
  941. and self.source is not None
  942. and name in ("_length_per_key", "_offset_per_key")
  943. ):
  944. with TracingContext.patch(force_unspec_int_unbacked_size_like=True):
  945. return super().var_getattr(tx, name)
  946. return super().var_getattr(tx, name)
  947. class RemovableHandleVariable(VariableTracker):
  948. REMOVED = -1
  949. def __init__(
  950. self,
  951. mutable_local=None,
  952. # index of the registration in the side_effects owned register_hook/handle list, used during removal.
  953. idx=None,
  954. **kwargs,
  955. ):
  956. super().__init__(**kwargs)
  957. self.mutable_local = mutable_local
  958. self.idx = idx
  959. def call_method(self, tx, method_name, args, kwargs):
  960. if method_name == "remove":
  961. if self.idx != self.REMOVED:
  962. tx.output.side_effects.remove_hook(self.idx)
  963. self.idx = self.REMOVED
  964. return variables.ConstantVariable.create(None)
  965. super().call_method(tx, method_name, args, kwargs)
  966. def reconstruct(self, codegen):
  967. if self.idx == self.REMOVED:
  968. # Hook has already been removed, return a dummy handle
  969. codegen.load_import_from("torch._dynamo.utils", "invalid_removeable_handle")
  970. codegen.extend_output(create_call_function(0, True))
  971. return
  972. # unreachable due to codegen.add_cache() when the hook is installed
  973. super().reconstruct(codegen)