# mypy: ignore-errors

import functools
import inspect
import logging
import operator
import textwrap
import types
import unittest
from typing import Dict, List

import sympy

import torch._numpy as tnp
import torch.fx
import torch.random
from torch._dynamo import compiled_autograd
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import (
    guard_scalar,
    GuardOnDataDependentSymNode,
    has_free_symbols,
    is_symbolic,
    SymTypes,
)
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from .. import config, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..bytecode_transformation import create_call_method
from ..current_scope_id import current_scope_id
from ..exc import unimplemented, UserError, UserErrorType
from ..external_utils import call_hook_from_backward_state
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource
from ..utils import (
    fqn,
    get_custom_getattr,
    get_fake_value,
    get_real_value,
    guard_if_dyn,
    object_has_getattribute,
    product,
    proxy_args_kwargs,
    set_example_value,
    tensortype_to_dtype,
)
from .base import _is_top_level_scope, VariableTracker
from .constant import ConstantVariable
from .lists import SizeVariable

try:
    import numpy as np
except ModuleNotFoundError:
    np = None

log = logging.getLogger(__name__)

# Ops that allow tensor <op> tensor
supported_tensor_comparison_ops = {
    ">": operator.gt,
    "<": operator.lt,
    ">=": operator.ge,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
}
# Ops that allow tensor <op> None
supported_const_comparison_ops = {
    "is": operator.is_,
    "is not": operator.is_not,
    "==": operator.eq,
    "!=": operator.ne,
}
supported_comparison_ops = {
    **supported_tensor_comparison_ops,
    **supported_const_comparison_ops,
}
supported_tensor_comparison_op_values = dict.fromkeys(
    supported_tensor_comparison_ops.values()
)
supported_const_comparison_op_values = dict.fromkeys(
    supported_const_comparison_ops.values()
)
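
# Illustrative note (not part of the original file): the `*_op_values` dicts above act as
# ordered sets keyed by the operator functions themselves, so callers can test membership
# by function rather than by string, e.g.
#   operator.gt in supported_tensor_comparison_op_values   -> True
#   operator.is_ in supported_tensor_comparison_op_values  -> False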


class TensorVariable(VariableTracker):
    """A torch.Tensor input or an intermediate value in the FX graph"""

    _nonvar_fields = {
        "proxy",
        "dtype",
        "device",
        "layout",
        "ndim",
        "size",
        "stride",
        "requires_grad",
        "is_quantized",
        "is_contiguous",
        "is_sparse",
        "class_type",
        "specialized_value",
        "_is_name_set",
        *VariableTracker._nonvar_fields,
    }

    def get_real_value(self):
        """
        Get the actual value represented by this variable if computation is run
        using the user-provided inputs.
        NOTE: this runs actual tensor computation and may be
        slow and memory-intensive.
        """
        return get_real_value(self.proxy.node, self.proxy.tracer)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        *,
        dtype,
        device,
        layout,
        ndim,
        requires_grad,
        is_quantized,
        is_sparse,
        class_type,
        has_grad_fn,
        size=None,
        stride=None,
        is_contiguous=None,
        _is_name_set=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.proxy = proxy
        self.dtype = dtype
        self.device = device
        self.layout = layout
        self.ndim = ndim
        self.size = size
        self.stride = stride
        self.requires_grad = requires_grad
        self.is_quantized = is_quantized
        self.is_contiguous = is_contiguous
        self.is_sparse = is_sparse
        self.class_type = class_type
        self.has_grad_fn = has_grad_fn
        if _is_name_set is None:
            # no need to rename inputs
            _is_name_set = self.proxy.node.op == "placeholder"
        self._is_name_set: bool = _is_name_set

    def debug_repr(self):
        # TODO: strip off fake tensor from repr here
        return repr(self.proxy.node.meta["example_value"])

    def as_proxy(self):
        return self.proxy

    def python_type(self):
        return self.class_type

    @staticmethod
    def specialize(value: torch.Tensor):
        props = {
            "dtype": value.dtype,
            "device": value.device,
            "layout": value.layout,
            "ndim": int(value.ndim),
            "requires_grad": value.requires_grad,
            "is_quantized": value.is_quantized,
            "is_sparse": value.is_sparse,
            "class_type": type(value),
        }
        try:
            props["has_grad_fn"] = value.grad_fn is not None
        except Exception:
            # Workaround for issues with create_parameter_op in Dynamo. Reading
            # grad_fn should never cause an issue.
            props["has_grad_fn"] = False

        if is_sparse_any(value) and not has_free_symbols(value):
            props["size"] = tuple(
                [int(s) if is_symbolic(s) else s for s in value.size()]
            )
        elif not has_free_symbols(value):
            # this is a fully static shape, and the keys on props here inform specialization.
            # We have to cast to int here, because these might get accessed as ConstantVariable, which has
            # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
            # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for
            # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and
            # I'd like to keep it around for now.
            props["size"] = tuple(
                # the non is_symbolic case applies to the jagged layout
                # NestedTensor case as singleton ints are not symbolic
                [int(s) if is_symbolic(s) else s for s in value.size()]
            )
            props["stride"] = tuple(value.stride())
            if torch._C._functorch.is_batchedtensor(value):
                # Batched tensors do not support contiguity patterns, so
                # we refrain from computing the `is_contiguous` property
                props["is_contiguous"] = None
            else:
                props["is_contiguous"] = tuple(
                    [
                        x
                        for x in torch._prims_common._memory_formats
                        if value.is_contiguous(memory_format=x)
                    ]
                )
        return props
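
    # Illustrative sketch (not part of the original file): for a plain dense input such as
    #   torch.ones(2, 3)
    # `specialize()` would produce roughly
    #   {"dtype": torch.float32, "device": device("cpu"), "layout": torch.strided,
    #    "ndim": 2, "requires_grad": False, "is_quantized": False, "is_sparse": False,
    #    "class_type": torch.Tensor, "has_grad_fn": False,
    #    "size": (2, 3), "stride": (3, 1),
    #    "is_contiguous": (<the memory formats for which value.is_contiguous(...) is True>)}
    # since the shape has no free symbols and the tensor is not batched.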

    def dynamic_getattr(self, tx, name):
        fake_val = self.proxy.node.meta["example_value"]
        # For getattrs on tensors without sources,
        # we can do better than the default (creating a GetAttrVariable)
        # if:
        # (1) the tensor is a traceable tensor subclass
        # (2) We are getattr'ing an inner tensor from that subclass
        if not self.source and is_traceable_wrapper_subclass(fake_val):
            fake_val = self.proxy.node.meta["example_value"]
            attrs, ctx = fake_val.__tensor_flatten__()
            proxy = getattr(self.as_proxy(), name)
            example_value = getattr(fake_val, name)
            if name in attrs:
                # attrs returned from tensor_flatten are always tensors
                assert isinstance(example_value, torch.Tensor)

                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(tx=tx, proxy=proxy, example_value=example_value)
            # any other attributes on the subclass (that are not methods)
            # are assumed to be constant metadata.
            elif not callable(example_value):
                from .builder import SourcelessBuilder

                return SourcelessBuilder.create(tx, example_value)

        if not (self.source and self.source.subguards_allowed()):
            raise NotImplementedError

        # For local source, we associate the real value. We use this real value
        # for implementing getattr fallthrough on the variable tracker base class.
        # Note - this scope construction is mirrored in guards
        # A subsequent PR will introduce a util.
        scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
        try:
            # We raise in case we get a TypeError bug w/ SuperSource.
            # SuperSource has bugs in it atm, and can produce code like
            # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__,
            # L['mod'].model.model.encoder.embed_positions)", scope)
            # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope.
            _input_associated_real_value = eval(self.source.name(), scope)
        except Exception as exc:
            raise NotImplementedError from exc

        if _input_associated_real_value is None:
            raise NotImplementedError

        if object_has_getattribute(_input_associated_real_value):
            raise NotImplementedError

        if get_custom_getattr(_input_associated_real_value):
            raise NotImplementedError

        real_value = getattr(_input_associated_real_value, name)
        if callable(real_value):
            # Callables have more nuanced handling, and we should let the existing system delegate here.
            # Raising was past behavior and so should always be sound to fall back.
            # Note - at a certain point we may want to handle
            raise NotImplementedError

        from ..guards import GuardBuilder
        from .builder import VariableBuilder

        attr_source = AttrSource(self.source, name)
        install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
        return VariableBuilder(tx, attr_source)(real_value)

    def method_attr_ndim(self, tx):
        if self.ndim is not None:
            return ConstantVariable.create(self.ndim)
        else:
            return self.call_method(tx, "dim", [], {})

    def method_attr_dtype(self, tx):
        if self.dtype is not None:
            return ConstantVariable.create(self.dtype)

    def method_attr_device(self, tx):
        if self.device is not None:
            return ConstantVariable.create(self.device)

    def method_attr_layout(self, tx):
        if self.layout is not None:
            return ConstantVariable.create(self.layout)

    def method_attr_is_cuda(self, tx):
        if self.device is not None:
            return ConstantVariable.create(self.device.type == "cuda")

    def method_attr_shape(self, tx):
        if self.size is not None:
            sizes = [variables.ConstantVariable.create(x) for x in self.size]
            return SizeVariable(sizes)
        else:
            return self.call_method(tx, "size", [], {})

    def method_attr_requires_grad(self, tx):
        if self.requires_grad is not None:
            return ConstantVariable.create(self.requires_grad)

    def method_attr_is_quantized(self, tx):
        if self.is_quantized is not None:
            return ConstantVariable.create(self.is_quantized)

    def method_attr_is_sparse(self, tx):
        if self.is_sparse is not None:
            return ConstantVariable.create(self.is_sparse)

    def method_attr_data(self, tx):
        return self.call_method(tx, "detach", [], {})

    def method_attr_grad_fn(self, tx):
        if self.has_grad_fn:
            unimplemented("TensorVariable has a grad_fn")
        else:
            return variables.ConstantVariable(None)

    def method_attr__version(self, tx):
        from ..tensor_version_op import _tensor_version

        return variables.TorchInGraphFunctionVariable(_tensor_version).call_function(
            tx, [self], {}
        )
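
    # Illustrative note (not part of the original file): `var_getattr` below looks these
    # handlers up by name, so tracing e.g. `t.is_cuda` resolves to `method_attr_is_cuda(tx)`
    # and constant-folds when the device is statically known, rather than emitting a node
    # into the FX graph.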

    def var_getattr(self, tx, name):
        from . import UserDefinedClassVariable

        if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
            unimplemented(f"Illegal getattr invocation {name} in strict mode")

        if name == "__class__":
            return UserDefinedClassVariable(self.python_type())

        handler = getattr(self, f"method_attr_{name}", None)
        result = handler(tx) if handler is not None else None

        # Add a guard for type matching, these guards are checked before tensor guards
        # In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
        # <tensor> is later changed to another type
        if (
            result is not None
            and self.source
            and self.source.subguards_allowed()
            and not (
                name not in ("grad", "requires_grad") and result.is_python_constant()
            )
        ):
            install_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
            result.source = AttrSource(self.source, name)

        # It's hard to get inplace view (metadata mutation) on graph input to work properly across
        # dynamo/aot/inductor, just fall back.
        if self.source is not None and hasattr(torch.ops.aten, name):
            fn = getattr(torch.ops.aten, name)
            if (
                hasattr(fn, "overloads")
                and hasattr(fn, fn.overloads()[0])
                and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags
            ):
                # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc.
                return variables.misc.DelayGraphBreakVariable(
                    source=AttrSource(self.source, name)
                )

        # For attributes (not methods) that were not caught in the special handling above,
        # (e.g. tensor.real), we handle these generically, assuming that the output type is
        # a tensor.
        if result is None and name != "grad":

            def try_generic_attr_handling():
                from .builder import wrap_fx_proxy
                from .misc import GetAttrVariable

                try:
                    static_attr = inspect.getattr_static(torch.Tensor, name)
                except AttributeError:
                    return None

                # Make sure this is an attribute, not a method.
                # type(torch.Tensor.H) should be "getset_descriptor"
                # This is because of the CPython implementation, see THPVariableType:
                # these attributes are implemented under tp_getset, which appear
                # as `getset_descriptor`s, (compared to, say, methods which appear
                # as `method_descriptor`s)
                if type(static_attr) != types.GetSetDescriptorType:
                    return None

                proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
                if self.source is not None:
                    return wrap_fx_proxy(
                        tx=tx, proxy=proxy, source=AttrSource(self.source, name)
                    )
                else:
                    return wrap_fx_proxy(tx=tx, proxy=proxy)

            result = try_generic_attr_handling()

        if result is None:
            result = self.dynamic_getattr(tx, name)

        if result is None:
            raise NotImplementedError

        return result

    def has_unpack_var_sequence(self, tx):
        return self.ndim > 0

    def unpack_var_sequence(self, tx, idxes=None):
        from .builder import wrap_fx_proxy_cls

        if idxes is None:
            if self.size:
                length = self.size[0]
            else:
                dyn_length = self.call_method(
                    tx, "size", [ConstantVariable.create(0)], {}
                )
                # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through
                # symbolic_shapes, but that end up as int/sympy.Integer
                assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable))
                if isinstance(dyn_length, SymNodeVariable):
                    length = dyn_length.evaluate_expr(tx.output)
                else:
                    length = dyn_length.value
            idxes = range(length)
        return [
            wrap_fx_proxy_cls(target_cls=type(self), tx=tx, proxy=self.as_proxy()[i])
            for i in idxes
        ]

    def _strict_mode_banned_ops(self):
        return torch._dynamo.config._autograd_backward_strict_mode_banned_ops
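
    # Illustrative sketch (not part of the original file): iterating a traced tensor, e.g.
    #   for row in t: ...
    # goes through `unpack_var_sequence`, which produces one wrapped proxy per index,
    # roughly equivalent to tracing [t[0], t[1], ..., t[length - 1]]; the length comes from
    # the static size when known, otherwise from a guarded symbolic `t.size(0)`.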

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
            unimplemented(f"Illegal method invocation {name} in strict mode")

        """
        Dispatch to a method-specific handler defined below. If the
        handler returns None (or doesn't exist) we put the method call
        in the graph.
        """
        try:
            handler_method = getattr(self, f"method_{name}")
        except AttributeError:
            pass
        else:
            try:
                result = handler_method(*args, **kwargs)
                if result:
                    return result
            except TypeError as e:
                unimplemented(f"unhandled args for {name}: {e}")

        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self, *args], kwargs),
            ),
        )
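
    # Illustrative note (not part of the original file): `t.dim()` dispatches to
    # `method_dim` (defined below) and constant-folds, while a method with no
    # `method_<name>` handler (e.g. `t.relu()`) falls through to the generic path above
    # and becomes a `call_method` node in the FX graph via `wrap_fx_proxy`.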

    def method_size(self, *args, **kwargs):
        return self._method_size_stride("size", *args, **kwargs)

    def method_stride(self, *args, **kwargs):
        return self._method_size_stride("stride", *args, **kwargs)

    def _method_size_stride(self, name, dim=None):
        dim = guard_if_dyn(dim)

        def make_const_size_variable(x, **options):
            return SizeVariable(
                [ConstantVariable.create(y, **options) for y in x], **options
            )

        RetVariable = (
            make_const_size_variable if name == "size" else ConstantVariable.create
        )

        # Technically, this should not be necessary, but I'm including it
        # for enhanced BC, in case example_value is sometimes not set
        # (it really should always be set though!)
        if (r := getattr(self, name)) is not None:
            if dim is None:
                return RetVariable(r)
            else:
                return ConstantVariable.create(r[dim])

        # It might still be constant! Consult the fake tensor and see
        if (fake := self.proxy.node.meta.get("example_value")) is not None:
            if dim is None:
                fake_r = getattr(fake, name)()
                if not has_free_symbols(fake_r):
                    # int conversion for safety, in case a SymInt refined
                    # to constant
                    return RetVariable(tuple(int(r) for r in fake_r))
            else:
                fake_r = getattr(fake, name)(dim)
                if not has_free_symbols(fake_r):
                    return ConstantVariable.create(int(fake_r))
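
    # Illustrative sketch (not part of the original file): with a statically-known shape,
    # `t.size()` returns a SizeVariable of ConstantVariables and `t.size(0)` a single
    # ConstantVariable, so no node is added to the graph; when the shape has free symbols
    # the handler above returns None and `call_method` traces `size` into the graph instead.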

    def method_numel(self):
        if self.size is not None:
            return ConstantVariable.create(product(self.size))

        # It might still be constant! Consult the fake tensor and see
        if (fake := self.proxy.node.meta.get("example_value")) is not None:
            fake_r = fake.numel()
            if not has_free_symbols(fake_r):
                return ConstantVariable.create(int(fake_r))

    method_nelement = method_numel

    def method_dim(self):
        if self.ndim is not None:
            return ConstantVariable.create(self.ndim)

    method_ndimension = method_dim

    def method_is_floating_point(self):
        if self.dtype is not None:
            return ConstantVariable.create(self.dtype.is_floating_point)

    def method_is_complex(self):
        if self.dtype is not None:
            return ConstantVariable.create(self.dtype.is_complex)

    def method_is_contiguous(self, memory_format=None):
        memory_format = (
            memory_format.as_python_constant()
            if memory_format is not None
            else torch.contiguous_format
        )
        if self.is_contiguous is not None:
            return ConstantVariable.create(memory_format in self.is_contiguous)
        elif (fake := self.proxy.node.meta.get("example_value")) is not None:
            return ConstantVariable.create(
                fake.is_contiguous(memory_format=memory_format)
            )

    def method_type(self, dtype=None, non_blocking=False, **kwargs):
        if (
            dtype is None
            and self.dtype is not None
            and isinstance(self.device, torch.device)
        ):
            tensortype = next(
                k for k, v in tensortype_to_dtype.items() if self.dtype in v
            )
            if self.device.type == "cuda":
                return ConstantVariable.create(f"torch.cuda.{tensortype.__name__}")
            else:
                return ConstantVariable.create(f"torch.{tensortype.__name__}")
        elif (
            dtype is not None
            and fqn(type(dtype.as_python_constant())) == "torch.tensortype"
        ):
            # torch.FloatTensor, etc. are all of type "torch.tensortype".
            # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type.
            # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args)
            tensor_type = dtype.as_python_constant()
            tensor_type_const = ConstantVariable.create(fqn(tensor_type))

            from ..symbolic_convert import InstructionTranslator
            from .builder import wrap_fx_proxy

            tx = InstructionTranslator.current_tx()
            if non_blocking:
                kwargs = {"non_blocking": non_blocking, **kwargs}

            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    "type",
                    *proxy_args_kwargs([self, tensor_type_const], kwargs),
                ),
            )
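
    # Illustrative sketch (not part of the original file): for a float32 tensor on a CUDA
    # device, `t.type()` with no arguments constant-folds to the string
    # "torch.cuda.FloatTensor", while `t.type(torch.DoubleTensor)` is traced roughly as
    # `t.type("torch.DoubleTensor")`, since fx cannot handle tensortype arguments directly.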

    def method_as_subclass(self, cls):
        if isinstance(cls, TensorSubclassVariable) and cls.source:
            from ..symbolic_convert import InstructionTranslator
            from .builder import VariableBuilder
            from .torch_function import TensorWithTFOverrideVariable

            tx = InstructionTranslator.current_tx()

            # [Note: __torch_function__] coerce this tensor variable into a TensorWithTFOverrideVariable
            # in eager, this is just a type change. This isn't sound if a __torch_function__ tensor subclass
            # defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call.
            # It is up to the user whether this is correct behavior or not.
            py_cls = cls.as_python_constant()
            torch_fn = VariableBuilder(
                tx,
                AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"),
            )(py_cls.__torch_function__.__func__)

            return TensorWithTFOverrideVariable.from_tensor_var(
                tx, self, py_cls, torch_fn
            )

    def method_get_device(self):
        if isinstance(self.device, torch.device):
            index = self.device.index if self.device.type != "cpu" else -1
            return ConstantVariable.create(index)

    def method_element_size(self):
        return ConstantVariable.create(self.dtype.itemsize)

    def method_numpy(self, *, force=False):
        if not config.trace_numpy:
            unimplemented("Tensor.numpy(). config.trace_numpy is False")
        if not np:
            unimplemented("Tensor.numpy(). NumPy is not available")

        if self.layout != torch.strided:
            raise TypeError(
                f"can't convert {self.layout} layout tensor to numpy. Use Tensor.dense() first"
            )

        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()

        # We don't check that the tensor is on CPU when force is False, as this
        # allows us to execute NumPy code on CUDA. Same for requires_grad=True
        if force and force.as_python_constant():
            # If the user set force=True we try to preserve the semantics (no gradients, move to CPU...)
            t = self.call_method(tx, "detach", [], {})
            proxy = tx.output.create_proxy("call_method", "cpu", (t.as_proxy(),), {})
        else:
            # Hacky way to create a view of self that will be marked as NumpyNdarrayVariable
            proxy = tx.output.create_proxy(
                "call_method", "view_as", *proxy_args_kwargs([self, self], {})
            )
        return NumpyNdarrayVariable.create(tx, proxy)
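
    # Illustrative note (not part of the original file): `t.numpy()` is therefore traced as a
    # tensor-to-tensor op (`t.view_as(t)`, or `t.detach().cpu()` when force=True) and only the
    # variable-tracker type changes to NumpyNdarrayVariable, so subsequent NumPy calls keep
    # flowing through torch._numpy inside the graph.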

    def method_tolist(self):
        from ..symbolic_convert import InstructionTranslator
        from .builder import SourcelessBuilder

        tx = InstructionTranslator.current_tx()

        def tolist(tensor, sub_proxy):
            def wrap(i, sub_proxy):
                # Sigh, we forgot to gate this, so this data dependent is on
                # by default and is load bearing in CI
                with unittest.mock.patch.object(
                    tx.fake_mode, "allow_scalar_outputs", True
                ):
                    return SymNodeVariable.create(
                        tx,
                        sub_proxy.item(),
                    )

            if tensor.dtype not in [
                torch.int8,
                torch.int16,
                torch.int32,
                torch.int64,
            ]:
                unimplemented("Input tensor for tolist must be an integer tensor")

            if tensor.dim() == 0:
                return wrap(tensor, sub_proxy)

            if tensor.dim() == 1:
                return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)]

            return [
                tolist(sub_tensor, sub_proxy=sub_proxy[i])
                for i, sub_tensor in enumerate(tensor)
            ]

        tensor = self.as_proxy().node.meta["example_value"]
        out = tolist(tensor, self.as_proxy())
        return SourcelessBuilder.create(tx, out)
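
    # Illustrative sketch (not part of the original file): for a 2x2 int64 tensor the
    # recursion above yields a nested Python list of SymNodeVariables, one `.item()` call per
    # element, roughly
    #   [[t[0][0].item(), t[0][1].item()], [t[1][0].item(), t[1][1].item()]]
    # which SourcelessBuilder then wraps as a list variable.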

    def method_backward(self, *args, **kwargs):
        unimplemented("Tensor.backward")

    def method_data_ptr(self, *args, **kwargs):
        unimplemented("Tensor.data_ptr")

    def method_item(self, *args, **kwargs):
        if not config.capture_scalar_outputs:
            self._warn_capture_scalar_outputs()
            unimplemented("Tensor.item")

    @staticmethod
    @functools.lru_cache(None)
    def _warn_capture_scalar_outputs():
        log.warning(
            textwrap.dedent(
                """\
                    Graph break from `Tensor.item()`, consider setting:
                        torch._dynamo.config.capture_scalar_outputs = True
                    or:
                        env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
                    to include these operations in the captured graph.
                """
            )
        )

    def method___len__(self):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        return self.call_method(tx, "size", [ConstantVariable.create(0)], {})

    def method___setitem__(self, key, value):
        def has_bool_key(v):
            if isinstance(v, TensorVariable):
                return v.dtype in (torch.bool, torch.int8)
            elif isinstance(v, variables.TupleVariable):
                return any(has_bool_key(item) for item in v.items)
            else:
                return False

        if (
            has_bool_key(key)
            and isinstance(value, TensorVariable)
            and value.requires_grad
            and torch.is_grad_enabled()
        ):
            unimplemented(
                "boolean masking setitem backwards, see https://github.com/pytorch/pytorch/issues/114123"
            )

        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        tx.output.create_proxy(
            "call_function",
            operator.setitem,
            *proxy_args_kwargs([self, key, value], {}),
        )
        return ConstantVariable.create(None)
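
    # Illustrative note (not part of the original file): an assignment such as `t[mask] = v`
    # is traced as an `operator.setitem` call_function node and evaluates to None, matching
    # Python's statement semantics; the boolean-mask-plus-requires_grad case above is the
    # one combination that still graph-breaks (see the linked issue).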

    def method_resize_(self, *args, **kwargs):
        unimplemented("Tensor.resize_")

    def method_resize_as_(self, *args, **kwargs):
        unimplemented("Tensor.resize_as_")

    def method_sparse_resize_(self, *args, **kwargs):
        unimplemented("Tensor.sparse_resize_")

    def method_sparse_resize_and_clear_(self, *args, **kwargs):
        unimplemented("Tensor.sparse_resize_and_clear_")

    def method_set_(self, *args, **kwargs):
        if len(args) > 1:
            # torch.Tensor.set_() has several overloads.
            # aten::set_.source_Tensor(Tensor) gets special handling
            # in AOTAutograd and functionalization, because it is the most common
            # overload and is used by FSDP.
            # graph-breaking on aten::set_source_Tensor_storage_offset for now,
            # unless we find that we need to make it work.
            unimplemented("Tensor.set_.source_Tensor_storage_offset")

    def method_add_(self, other, *, alpha=None):
        if alpha is not None:
            from ..symbolic_convert import InstructionTranslator

            tx = InstructionTranslator.current_tx()
            result = variables.TorchInGraphFunctionVariable(torch.mul).call_function(
                tx, [other, alpha], {}
            )
            return self.call_method(tx, "add_", [result], {})

    def method_addcdiv_(self, tensor1, tensor2, *, value=None):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        if value is not None:
            result = variables.TorchInGraphFunctionVariable(torch.div).call_function(
                tx, [tensor1, tensor2], {}
            )
            result = variables.TorchInGraphFunctionVariable(torch.mul).call_function(
                tx, [result, value], {}
            )
            return self.call_method(tx, "add_", [result], {})

    def method___contains__(self, arg):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()

        # Rewrite __contains__ here so that downstream passes can trace through
        # without dealing with unbacked symbool. Roughly the code we translate is:
        # def __contains__(self, x):
        #     return (x == self).any().item()
        result = variables.TorchInGraphFunctionVariable(torch.eq).call_function(
            tx, [self, arg], {}
        )
        result = variables.TorchInGraphFunctionVariable(torch.any).call_function(
            tx, [result], {}
        )
        return result.call_method(tx, "item", [], {})
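
    # Illustrative sketch (not part of the original file): the in-place decompositions above
    # rewrite e.g. `t.add_(x, alpha=2)` into `t.add_(torch.mul(x, 2))` and
    # `t.addcdiv_(a, b, value=c)` into `t.add_(torch.mul(torch.div(a, b), c))`, so only
    # plain binary ops need to appear in the traced graph.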

    def method_redistribute(self, *args, **kwargs):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
        # and rewrite args to have only proxyable args, then insert call_function
        args_as_value = [x.as_python_constant() for x in args]
        kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

        def redistribute_fn_with_prim_types(x):
            return x.redistribute(*args_as_value, **kwargs_as_value)

        # attach the same function name for better debugging
        redistribute_fn_with_prim_types.__name__ = "prim_redistribute"

        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                redistribute_fn_with_prim_types,
                *proxy_args_kwargs([self], {}),
            ),
        )

    def method_to_local(self, *args, **kwargs):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
        # and rewrite args to have only proxyable args, then insert call_function
        args_as_value = [x.as_python_constant() for x in args]
        kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

        def to_local_fn_with_prim_types(x):
            return x.to_local(*args_as_value, **kwargs_as_value)

        # attach the same function name for better debugging
        to_local_fn_with_prim_types.__name__ = "prim_to_local"

        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                to_local_fn_with_prim_types,
                *proxy_args_kwargs([self], {}),
            ),
        )

    def method_register_hook(self, *args, **kwargs):
        return self._method_register_hook("register_hook", *args, **kwargs)

    def method_register_post_accumulate_grad_hook(self, *args, **kwargs):
        return self._method_register_hook(
            "register_post_accumulate_grad_hook", *args, **kwargs
        )

    def _method_register_hook(self, name: str, hook: VariableTracker):
        # Note - do not arbitrarily add hooks here - make sure they match the same contract
        # see [On tensor.register_hook]
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()

        if not self.source:
            if not compiled_autograd.compiled_autograd_enabled:
                # TODO(voz):
                # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
                # python state.
                # We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is unsafe to run
                # them in a compiled bwd without re-entering dynamo as compiled_autograd does.
                #
                # Discussion point 1 - Should we bypass this if nopython/fullgraph = True?
                #   No. Because this was going to be a graph break anyway - this check does not
                #   introduce new graph breaks where there were none.
                #
                # Discussion point 2 - Should we defer this check to backwards?
                #   No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user
                #   would have no recourse - their forward traces just fine, but will fail at backwards unless
                #   compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today)
                #   then they have nothing they can do except disable compile.
                unimplemented(
                    "Compilation of intermediate hooks requires compiled autograd"
                )

            hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook)

            def _register_hook_trampoline(tensor, bw_state):
                register_hook = getattr(tensor, name)
                register_hook(
                    functools.partial(
                        trace_wrapped,
                        fn=call_hook_from_backward_state,
                        bw_state=bw_state,
                        hook_name=hook_name,
                    )
                )
                # TODO(jansel): returning None here is wrong, it should be
                # RemovableHandle, but we need some extra work to support
                # this properly.
                return None

            from .builder import wrap_fx_proxy

            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_function",
                    _register_hook_trampoline,
                    (self.as_proxy(), bw_state_proxy),
                    {},
                ),
            )

        handle_variable = variables.RemovableHandleVariable(
            mutable_local=variables.base.MutableLocal(),
        )
        tx.output.side_effects.register_hook(self, hook, handle_variable, name)
        return handle_variable
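
    # Illustrative note (not part of the original file): hooks on graph *inputs* (tensors with
    # a source) are recorded as a side effect and return a RemovableHandleVariable, while hooks
    # on *intermediates* are traced through `_register_hook_trampoline` and require compiled
    # autograd to be enabled, as explained in the comments above.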

    def method_requires_grad_(self, requires_grad=True):
        if requires_grad is not True:
            requires_grad = requires_grad.as_python_constant()

        if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad:
            unimplemented("Tensor.requires_grad_")
        else:
            return self

    def method_new(self, *args, **kwargs):
        # Convert x.new(torch.Size) into x.new_empty(torch.Size),
        # as Tensor.new acts differently with a Size input versus a tuple input.
        if (len(args) == 1 and isinstance(args[0], SizeVariable)) or (
            len(args) >= 1
            and all(
                isinstance(a, ConstantVariable) and a.python_type() == int for a in args
            )
        ):
            from ..symbolic_convert import InstructionTranslator

            return self.call_method(
                InstructionTranslator.current_tx(), "new_empty", args, kwargs
            )
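
    # Illustrative sketch (not part of the original file): both `x.new(2, 3)` and
    # `x.new(torch.Size([2, 3]))` are rewritten above to `x.new_empty(...)`, i.e. treated as
    # allocation requests rather than as data-copying Tensor.new(data) calls.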

    def method_untyped_storage(self):
        return UntypedStorageVariable(
            self, self.as_proxy().node.meta["example_value"].untyped_storage()
        )

    def set_name_hint(self, name: str):
        # Only rename at the top-level scope; this is to avoid confusion between
        # mutating a variable and renaming it (e.g. a = b) while speculating a higher-order op,
        # where mutation is prohibited and it's difficult to distinguish it from renaming.
        if not self._is_name_set and _is_top_level_scope(current_scope_id()):
            self.proxy.node._rename(name)
            self._is_name_set = True


class SymNodeVariable(VariableTracker):
    """
    Represents a symbolic scalar, either int, float or bool. This is most commonly used to
    handle symbolic size computation, e.g., tensor.size(0), but it is also used to
    handle logic like float_tensor.item() or unspecialized float inputs.
    """

    _nonvar_fields = {
        "proxy",
        "sym_num",
        *VariableTracker._nonvar_fields,
    }

    def debug_repr(self):
        return repr(self.sym_num)

    @classmethod
    def create(cls, tx, proxy, sym_num=None, **options):
        if sym_num is None:
            sym_num = get_fake_value(proxy.node, tx)
        if "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == sym_num
        set_example_value(proxy.node, sym_num)

        if isinstance(sym_num, (sympy.Integer, int, bool)):
            sym_num = int(sym_num) if isinstance(sym_num, sympy.Integer) else sym_num
            return ConstantVariable.create(sym_num)

        return SymNodeVariable(proxy, sym_num, **options)

    def __init__(self, proxy, sym_num, **kwargs):
        super().__init__(**kwargs)
        self.proxy = proxy
        # TODO: Should we allow non SymTypes here? Today it is allowed
        self.sym_num = sym_num
        self._tensor_var = None

    def python_type(self):
        if isinstance(self.sym_num, SymTypes):
            return self.sym_num.node.pytype
        else:
            return type(self.sym_num)

    def as_proxy(self):
        return self.proxy

    def as_tensor(self, tx):
        if self._tensor_var is None:
            from .builder import SourcelessBuilder

            self._tensor_var = SourcelessBuilder.create(
                tx, torch.scalar_tensor
            ).call_function(tx, [self], {})
        return self._tensor_var

    def evaluate_expr(self, output_graph=None):
        try:
            return guard_scalar(self.sym_num)
        except GuardOnDataDependentSymNode as e:
            raise UserError(  # noqa: B904
                UserErrorType.ANTI_PATTERN,
                f"Consider annotating your code using torch._check*(). {str(e)}",
                case_name="constrain_as_size_example",
            )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self, *args], kwargs),
            ),
        )
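
# Illustrative note (not part of the original file): `evaluate_expr` specializes a symbolic
# scalar to its current concrete value via `guard_scalar` (which installs the corresponding
# guard), and turns a data-dependent failure into a UserError pointing users at
# torch._check*() annotations.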


class NumpyNdarrayVariable(TensorVariable):
    """
    Represents a np.ndarray, but backed by torch Tensor via torch._numpy.ndarray.
    Use this for Tensor.numpy() call.
    """

    @staticmethod
    def create(tx, proxy, **options):
        from .builder import wrap_fx_proxy_cls

        return wrap_fx_proxy_cls(
            target_cls=NumpyNdarrayVariable,
            tx=tx,
            proxy=proxy,
            **options,
        )

    def var_getattr(self, tx, name):
        # NB: This INTENTIONALLY does not call super(), because there is
        # no intrinsic reason ndarray properties are related to Tensor
        # properties. The inheritance here is for implementation sharing.

        from ..utils import numpy_attr_wrapper
        from .builder import wrap_fx_proxy

        result = None

        example_value = self.as_proxy().node.meta["example_value"]
        example_ndarray = tnp.ndarray(example_value)

        def insert_into_graph():
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {}
                ),
            )

        if name in ["T", "real", "imag"]:
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_attr_wrapper,
                (self.as_proxy(), name),
                {},
            )
            result = NumpyNdarrayVariable.create(tx, proxy)

        # These are awkward to implement. The standard playbook for torch._numpy
        # interop is to trace a call into the torch._numpy wrapper which works for
        # Tensor operations. However, we don't want to do this for calls
        # that don't return Tensors, because in those cases we may not want
        # to trace the attribute access into the graph at all (it is sort
        # of harmless to do so, because AOTAutograd will eliminate them,
        # but it's best not to trace them in to begin with.) But in any
        # case, tracing these into the graph is like trying to fit a square
        # peg into a round hole; best not to do it. So instead we
        # painstakingly implement these by hand
        #
        # NB: only ALWAYS specialized attributes can go here; notably,
        # size/shape not allowed!
        elif name in ("ndim", "itemsize"):
            return ConstantVariable.create(getattr(example_ndarray, name))
        elif name in ("shape", "stride"):
            if not has_free_symbols(r := getattr(example_ndarray, name)):
                return ConstantVariable.create(tuple(int(r) for r in r))
            return insert_into_graph()
        elif name == "size":
            if not has_free_symbols(r := example_ndarray.size):
                return ConstantVariable.create(int(r))
            return insert_into_graph()
        elif name in ["base", "flags", "dtype"]:
            unimplemented(f"TODO: add support for ndarray.{name}")
        elif name in ["__version__"]:
            unimplemented("delegate np.__version__ to NumPy")

        if result is None:
            raise NotImplementedError
        return result

    @staticmethod
    def patch_args(name, args, kwargs):
        if name == "clip":
            kwargs_rename = {"a_min": "min", "a_max": "max"}
            kwargs = {kwargs_rename.get(k, k): v for k, v in kwargs.items()}
        return args, kwargs
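
    # Illustrative sketch (not part of the original file): `patch_args` renames NumPy-style
    # keyword arguments onto their torch equivalents before dispatch, e.g.
    #   arr.clip(a_min=0, a_max=1)  ->  traced through the torch._numpy clip wrapper with min=0, max=1
    # so the single numpy_method_wrapper path below handles both spellings.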

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from ..utils import numpy_method_wrapper

        args, kwargs = self.patch_args(name, args, kwargs)

        if name in ["__len__", "size", "tolist"]:
            # delegate back to TensorVariable
            return super().call_method(tx, name, args, kwargs)
        if name == "tobytes":
            unimplemented("tobytes is not modelled in torch._numpy")

        proxy = tx.output.create_proxy(
            "call_function",
            numpy_method_wrapper(name),
            *proxy_args_kwargs([self] + list(args), kwargs),
        )
        return NumpyNdarrayVariable.create(tx, proxy)

    def python_type(self):
        return np.ndarray


class UnspecializedPythonVariable(TensorVariable):
    """
    This is a 1-element tensor representing an unspecialized python float/int.
    """

    _nonvar_fields = {
        "raw_value",
        "need_unwrap",
        *TensorVariable._nonvar_fields,
    }

    def __init__(
        self, proxy: torch.fx.Proxy, *, raw_value=None, need_unwrap=True, **kwargs
    ):
        super().__init__(proxy, **kwargs)
        self.raw_value = raw_value
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
        # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
        return UnspecializedPythonVariable(
            **dict(tensor_variable.__dict__),
            raw_value=raw_value,
            need_unwrap=need_unwrap,
        )


class FakeItemVariable(TensorVariable):
    """An unspecialized python variable which prevents access to the underlying raw value.
    This is needed if item is called on a FakeTensor."""

    _nonvar_fields = {
        "need_unwrap",
        *TensorVariable._nonvar_fields,
    }

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        need_unwrap = kwargs.pop("need_unwrap", False)
        super().__init__(proxy, **kwargs)
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable):
        return FakeItemVariable(**dict(tensor_variable.__dict__))


class TensorSubclassVariable(VariableTracker):
    def __init__(self, value, *args, **kwargs):
        self.value = value
        super().__init__(*args, **kwargs)

    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ) -> VariableTracker:
        if len(args) == 1 and isinstance(args[0], TensorVariable):
            from .builder import VariableBuilder
            from .torch_function import TensorWithTFOverrideVariable

            torch_fn = VariableBuilder(
                tx, AttrSource(self.source, "__torch_function__")
            )(self.value.__torch_function__)

            return TensorWithTFOverrideVariable.from_tensor_var(
                tx, args[0], self.value, torch_fn
            )

        return super().call_function(tx, args, kwargs)

    def as_python_constant(self):
        return self.value

    def python_type(self):
        return type(self.value)


class UntypedStorageVariable(VariableTracker):
    _nonvar_fields = {
        "example_value",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self,
        from_tensor: TensorVariable,
        example_value: torch.UntypedStorage,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.from_tensor = from_tensor
        # Example_value will always have device="meta"
        self.example_value = example_value

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if name == "size":
            assert not args
            assert not kwargs
            result = self.example_value.size()
            if not has_free_symbols(result):
                # avoid creating a node in the graph
                return ConstantVariable.create(int(result))
            else:
                from ..external_utils import untyped_storage_size
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function",
                        untyped_storage_size,
                        (self.from_tensor.as_proxy(),),
                        {},
                    ),
                )
        if name == "resize_" and len(args) == 1:
            assert not kwargs
            tx.output.create_proxy(
                "call_function",
                torch.ops.inductor.resize_storage_bytes_,
                (self.from_tensor.as_proxy(), args[0].as_proxy()),
                {},
            )
            return self

        return super().call_method(tx, name, args, kwargs)

    def reconstruct(self, codegen):
        codegen(self.from_tensor)
        codegen.append_output(codegen.create_load_method("untyped_storage"))
        codegen.extend_output(create_call_method(0))