# misc.py
# mypy: ignore-errors
import collections
import dataclasses
import functools
import inspect
import itertools
import re
import sys
import types
from typing import Dict, List

import torch._C
import torch._numpy as tnp
import torch.utils._pytree as pytree

from .. import config, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
from ..utils import (
    check_unspec_or_constant_args,
    identity,
    is_tensor_base_attr_getter,
    proxy_args_kwargs,
    set_example_value,
)
from .base import VariableTracker
from .functions import NestedUserFunctionVariable, UserFunctionVariable
from .user_defined import is_standard_setattr, UserDefinedObjectVariable


class SuperVariable(VariableTracker):
    _nonvar_fields = {
        "specialized",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, typevar, objvar=None, specialized=False, **kwargs):
        super().__init__(**kwargs)
        # typevar is the first argument to super(). In the case where no argument
        # is provided to super(), it is the __class__ object where
        # the super() function is being called
        self.typevar = typevar
        # objvar here must be an instance or subtype of typevar.
        # In the case where super() is called without arguments, it is the first argument
        # to the current function where super() is called from (self for a regular method,
        # cls for a classmethod)
        self.objvar = objvar
        self.specialized = specialized  # directly get attr from self.typevar if true

    def reconstruct(self, codegen):
        codegen(variables.BuiltinVariable(super))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            codegen.extend_output(create_call_function(2, True))
        else:
            codegen.extend_output(create_call_function(1, True))

    def _resolved_getattr_and_source(self, tx, name):
        assert self.objvar, "1-arg super not implemented"
        if self.specialized:
            return getattr(self.typevar.as_python_constant(), name)
        search_type = self.typevar.as_python_constant()

        # The rest of this function does two things:
        #   - Walk the mro to find where the attribute comes from to be
        #     able to provide accurate source
        #   - Call the getattr to get the object

        # Find the class object, where the function lives.
        # When objvar is "self", use type(self), when objvar is "cls", use it as-is
        type_to_use = self.objvar.python_type()
        type_to_use_source = (
            TypeSource(self.objvar.source) if self.objvar.source else None
        )
        if issubclass(type_to_use, type):
            type_to_use = self.objvar.value
            type_to_use_source = self.objvar.source

        source = None
        if self.objvar.source is not None:
            # Walk the mro tuple to find out the actual class where the
            # attribute resides.
            search_mro = type_to_use.__mro__
            start_index = search_mro.index(search_type) + 1
            for index in range(start_index, len(search_mro)):
                if hasattr(search_mro[index], name):
                    # Equivalent of something like type(L['self']).__mro__[1].attr_name
                    source = AttrSource(
                        GetItemSource(AttrSource(type_to_use_source, "__mro__"), index),
                        name,
                    )
                    break

        # TODO(jansel): there is a small chance this could trigger user code, prevent that
        return getattr(super(search_type, type_to_use), name), source

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        # Check if getattr is a constant. If not, delay the actual work by
        # wrapping the result in GetAttrVariable. Mostly super is called with a
        # method, so most of the work is delayed to call_function.
        #
        # We could have just implemented a const_getattr. However, super is
        # special when it comes to finding sources. Compared to other VTs, super
        # requires the attr name to walk the mro and find the actual source (and
        # not just AttrSource).
        value, source = self._resolved_getattr_and_source(self, name)
        if not variables.ConstantVariable.is_literal(value):
            return GetAttrVariable(self, name)
        if source:
            install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
            return variables.ConstantVariable.create(value, source=source)
        return variables.ConstantVariable.create(value)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        inner_fn, source = self._resolved_getattr_and_source(self, name)

        if inner_fn is object.__init__:
            return LambdaVariable(identity)
        elif inner_fn is torch.nn.Module.__init__:
            objvar = self.objvar
            from ..side_effects import AttributeMutationNew

            if (
                isinstance(objvar, variables.UserDefinedObjectVariable)
                and isinstance(objvar.mutable_local, AttributeMutationNew)
                and not (args or kwargs)
            ):
                tx.output.side_effects.store_attr(
                    objvar,
                    "__call_nn_module_init",
                    variables.ConstantVariable.create(True),
                )
                return variables.ConstantVariable.create(None)
            else:
                unimplemented("super() nn.Module.__init__")
        elif isinstance(inner_fn, types.FunctionType):
            return variables.UserFunctionVariable(
                inner_fn, source=source
            ).call_function(tx, [self.objvar] + args, kwargs)
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source
            ).call_function(tx, args, kwargs)
        elif (
            inner_fn is collections.OrderedDict.__getitem__
            and isinstance(self.objvar, variables.UserDefinedObjectVariable)
            and self.objvar.source
            and len(args) == 1
            and len(kwargs) == 0
            and args[0].is_python_constant()
        ):
            from .builder import VariableBuilder

            key = args[0].as_python_constant()
            return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
                collections.OrderedDict.__getitem__(self.objvar.value, key)
            )
        elif inner_fn in (
            collections.OrderedDict.__setitem__,
            object.__setattr__,
        ) and isinstance(self.objvar, variables.CustomizedDictVariable):
            assert not kwargs and len(args) == 2
            return super(variables.CustomizedDictVariable, self.objvar).call_method(
                tx, "__setitem__", args, kwargs
            )
        elif is_standard_setattr(inner_fn) and isinstance(
            self.objvar, UserDefinedObjectVariable
        ):
            return self.objvar.method_setattr_standard(tx, *args, **kwargs)
        unimplemented(f"non-function or method super: {inner_fn}")
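

# For example, tracing user code like the following (a sketch) goes through
# SuperVariable.call_method above:
#
#     class MyModule(torch.nn.Module):
#         def __init__(self):
#             super().__init__()                   # torch.nn.Module.__init__ branch
#
#     class MyDict(collections.OrderedDict):
#         def get_foo(self):
#             return super().__getitem__("foo")    # OrderedDict.__getitem__ branch
#
# The zero-argument form binds typevar to the enclosing __class__ and objvar to
# the first argument of the current function (self/cls), as noted in __init__.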


class ExceptionVariable(VariableTracker):
    def __init__(self, exc_type, args, **kwargs):
        super().__init__(**kwargs)
        self.exc_type = exc_type
        self.args = args

    def reconstruct(self, codegen):
        codegen.load_import_from("builtins", self.exc_type.__name__)
        codegen.foreach(self.args)
        codegen.call_function(len(self.args), True)


class UnknownVariable(VariableTracker):
    """
    It could be anything!
    """


class DelayGraphBreakVariable(UnknownVariable):
    """
    Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION.
    """


class ComptimeVariable(VariableTracker):
    """
    This variable is special: it lets you execute arbitrary code at
    Dynamo compile time.
    """

    def reconstruct(self, codegen):
        raise NotImplementedError("comptime is special form")

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        from ..comptime import comptime

        # To support the comptime.print_graph convenience accessors
        from .functions import UserFunctionVariable

        return UserFunctionVariable(
            getattr(comptime, name), source=AttrSource(self.source, name)
        )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from ..comptime import ComptimeContext

        # TODO: support an expression form as well
        assert not kwargs
        # Second argument is runtime lambda, ignored
        assert len(args) <= 2
        fn = args[0]
        if isinstance(fn, UserFunctionVariable):
            fn.get_function()(ComptimeContext(tx))
        elif isinstance(fn, NestedUserFunctionVariable):
            # We have to manually bind the freevars ourselves
            code = fn.get_code()
            assert not fn.closure, (
                "comptime function must not have free variables, "
                f"but these variables were free: {code.co_freevars}"
            )
            func = types.FunctionType(
                code,
                fn.f_globals,
                fn.fn_name.as_python_constant(),
                tuple(fn.defaults.items) if fn.defaults else None,
                # We could automatically promote free variables into
                # ComptimeVar but this is confusing if you access
                # a free variable that we actually DO have the runtime
                # value for
                # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
                tuple(),
            )
            func(ComptimeContext(tx))
        else:
            raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")
        return variables.ConstantVariable.create(None)
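

# A sketch of the usage this models: calling torch._dynamo.comptime.comptime
# with a function of one argument runs that function against a ComptimeContext
# while the enclosing code is still being compiled, e.g.
#
#     from torch._dynamo.comptime import comptime
#
#     def fn(x):
#         y = x + 1
#         comptime(lambda ctx: ctx.print_graph())  # executes at compile time
#         return y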


class ClosureVariable(UnknownVariable):
    _nonvar_fields = {
        "name",
        *UnknownVariable._nonvar_fields,
    }

    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        codegen.append_output(codegen.create_load_closure(self.name))


# closure variable created by an inlined function
class InlinedClosureVariable(UnknownVariable):
    _nonvar_fields = {
        "name",
        *UnknownVariable._nonvar_fields,
    }

    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        codegen.append_output(codegen.create_load_closure(self.name))


class NewCellVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class NewGlobalVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class InspectSignatureVariable(VariableTracker):
    """represents inspect.signature(...)"""

    @staticmethod
    def create(callable, **kwargs):
        if kwargs:
            unimplemented(f"inspect.signature with {kwargs}")
        return InspectSignatureVariable(callable)

    def __init__(self, inspected: VariableTracker, **kwargs):
        super().__init__(**kwargs)
        self.inspected = inspected

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        if name == "parameters":
            return variables.ConstDictVariable(
                {
                    variables.ConstantVariable.create(name): InspectParameterVariable()
                    for name in self.inspected.inspect_parameter_names()
                },
                user_cls=dict,
            )
        return super().var_getattr(tx, name)


class InspectParameterVariable(VariableTracker):
    """This is not implemented; if used, it will graph break."""

    pass
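

# For example, traced user code like (a sketch)
#
#     params = inspect.signature(some_callable).parameters
#
# is modeled by InspectSignatureVariable; accessing .parameters yields a dict
# of InspectParameterVariable entries (see var_getattr above), and any further
# use of those entries falls back to a graph break.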


def produce_trampoline_autograd_apply(fn_cls):
    def trampoline_autograd_apply(*args, **kwargs):
        return fn_cls.apply(*args, **kwargs)

    trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
    return trampoline_autograd_apply


class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass"""

    _nonvar_fields = {
        "fn_cls",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, fn_cls, **kwargs):
        super().__init__(**kwargs)
        self.fn_cls = fn_cls

    def call_apply(self, tx, args, kwargs):
        requires_grad = False

        def visit(node):
            nonlocal requires_grad
            if isinstance(node, variables.TensorVariable):
                if node.requires_grad is not False:
                    requires_grad = True
            if isinstance(node, variables.NNModuleVariable):
                if node.is_training(tx):
                    requires_grad = True

        VariableTracker.visit(visit, (args, kwargs))

        if (
            requires_grad
            and torch.is_grad_enabled()
            and config.capture_autograd_function
        ):
            from torch._functorch.autograd_function import (
                autograd_function_forward_rewritten,
            )
            from torch.autograd.function import _is_setup_context_defined

            forward_fn = self.fn_cls.forward

            is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
            if is_setup_ctx_defined:
                # If setup_context is defined, we generate a new forward function which includes
                # the original forward and setup_context function, and trace the new forward function.
                forward_fn = autograd_function_forward_rewritten(
                    self.fn_cls.forward, self.fn_cls.setup_context
                )

            vjp_fn = self.fn_cls.vjp  # type: ignore[attr-defined]
            if vjp_fn is not torch.autograd.Function.vjp:
                unimplemented("NYI - User defined vjp")

            jvp_fn = self.fn_cls.jvp  # type: ignore[attr-defined]
            if jvp_fn is not torch.autograd.Function.jvp:
                unimplemented("NYI - User defined jvp")

            from .higher_order_ops import AutogradFunctionApplyVariable

            source = self.source
            if source is None:
                source = AttrSource(
                    tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
                )

            val = AutogradFunctionApplyVariable(
                forward_fn,
                self.fn_cls.backward,
                source,
                source=AttrSource(source, member="apply"),
            ).call_function(tx, args, kwargs)
            # Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping
            # the forward function, as we don't want to generate guards for new_forward.__closure__
            # if forward is rewritten by autograd_function_forward_rewritten.
            # But we still need to generate correct guards for the original forward and setup_context
            # functions, so we have to add guards manually.
            if self.source:
                fwd_src = AttrSource(self.source, "forward")
                install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH))
                if is_setup_ctx_defined:
                    setup_ctx_src = AttrSource(self.source, "setup_context")
                    install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH))

            return val

        if self.source:
            source = AttrSource(self.source, "forward")
        else:
            source = None

        fn = self.fn_cls.forward
        ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
        args = [ctx, *args]
        if isinstance(fn, types.FunctionType):
            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            return variables.UserMethodVariable(
                fn.__func__,
                variables.UserDefinedClassVariable(self.fn_cls),
                source=source,
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(
                f"non-function or method in subclass of torch.autograd.Function: {fn}"
            )

    def call_backward(self, tx, args, kwargs):
        fn = self.fn_cls.backward
        self.source = AttrSource(self.source, "backward")
        assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
        assert isinstance(fn, types.FunctionType)

        return variables.UserFunctionVariable(fn, source=self.source).call_function(
            tx, args, kwargs
        )

    def call_function(self, tx, args, kwargs):
        return AutogradFunctionVariable(self.fn_cls)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        from ..trace_rules import is_callable_allowed
        from .builder import wrap_fx_proxy

        if name == "apply":
            if is_callable_allowed(self.fn_cls):
                trampoline_autograd_apply = produce_trampoline_autograd_apply(
                    self.fn_cls
                )
                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        trampoline_autograd_apply,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
            else:
                return self.call_apply(tx, args, kwargs)

        elif name == "backward":
            return self.call_backward(tx, args, kwargs)
        else:
            from .. import trace_rules

            source = AttrSource(self.source, name) if self.source is not None else None
            try:
                obj = inspect.getattr_static(self.fn_cls, name)
            except AttributeError:
                obj = None

            if isinstance(obj, staticmethod):
                func = obj.__get__(self.fn_cls)
                if source is not None:
                    return (
                        trace_rules.lookup(func)
                        .create_with_source(func, source=source)
                        .call_function(tx, args, kwargs)
                    )
                else:
                    return trace_rules.lookup(func)(func).call_function(
                        tx, args, kwargs
                    )
            elif isinstance(obj, classmethod):
                return variables.UserMethodVariable(
                    obj.__func__, self, source=source
                ).call_function(tx, args, kwargs)
            else:
                unimplemented(f"Unsupported method: {name}")
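

# A sketch of the kind of autograd.Function subclass handled by call_apply above
# when config.capture_autograd_function is enabled (names are illustrative):
#
#     class MyFn(torch.autograd.Function):
#         @staticmethod
#         def forward(x):
#             return x.sin()
#
#         @staticmethod
#         def setup_context(ctx, inputs, output):
#             ctx.save_for_backward(inputs[0])
#
#         @staticmethod
#         def backward(ctx, grad):
#             (x,) = ctx.saved_tensors
#             return grad * x.cos()
#
# Because setup_context is defined, forward and setup_context are fused via
# autograd_function_forward_rewritten before being traced through
# AutogradFunctionApplyVariable; user-defined vjp/jvp remain unsupported.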


@dataclasses.dataclass
class SavedTensorBox:
    tensors: List[VariableTracker] = dataclasses.field(default_factory=list)


class AutogradFunctionContextVariable(UserDefinedObjectVariable):
    """
    Tracks an autograd.Function() context using mutation tracking in side_effects.py
    """

    _nonvar_fields = {
        "proxy",
        "inference",
        "saved_tensors",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        value_type=None,
        inference=False,
        proxy=None,
        saved_tensors=None,
        needs_input_grad=None,
        **kwargs,
    ):
        super().__init__(value=value, value_type=value_type, **kwargs)
        self.inference = inference
        self.proxy = proxy
        self.saved_tensors = saved_tensors
        self.needs_input_grad = needs_input_grad

    @staticmethod
    def create(tx, args=None, kwargs=None):
        needs_input_grad = None
        if args and not kwargs:
            needs_input_grad = tuple(
                isinstance(x, variables.TensorVariable) and x.requires_grad
                for x in args
            )
        proxy = tx.output.create_proxy(
            "call_function", torch.autograd.function.FunctionCtx, tuple(), {}
        )
        out = tx.output.side_effects.track_object_new(
            None,
            torch.autograd.function.FunctionCtx,
            functools.partial(
                AutogradFunctionContextVariable,
                inference=True,
                proxy=proxy,
                saved_tensors=SavedTensorBox(),
                needs_input_grad=needs_input_grad,
            ),
            {},
        )
        set_example_value(proxy.node, out.value)

        return out

    def as_proxy(self):
        if self.proxy is None:
            unimplemented("proxy not set")
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__setattr__":
            return super().call_method(tx, name, args, kwargs)
        if name != "save_for_backward":
            unimplemented(f"autograd.Function context method: {name}")
        if self.saved_tensors is None:
            unimplemented(
                "save_for_backward only supported on a newly constructed FunctionCtx"
            )

        if not self.inference:
            assert self.source and not kwargs
            tx.output.side_effects.track_save_for_backward(self, args)

        # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls.
        if len(self.saved_tensors.tensors) > 0:
            self.saved_tensors.tensors = []
        for arg in args:
            self.saved_tensors.tensors.append(arg)
        return variables.ConstantVariable.create(None)

    def var_getattr(self, tx, name):
        if name == "save_for_backward":
            return LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        if name == "saved_tensors" and self.saved_tensors is not None:
            return variables.TupleVariable(list(self.saved_tensors.tensors))
        if name == "needs_input_grad":
            if self.needs_input_grad is not None:
                return variables.ConstantVariable.create(self.needs_input_grad)
            if self.source:
                from .builder import VariableBuilder

                return VariableBuilder(tx, AttrSource(self.source, "needs_input_grad"))(
                    self.value.needs_input_grad
                )
        return super().var_getattr(tx, name)
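

# Inside a traced autograd.Function, the ctx object is modeled by
# AutogradFunctionContextVariable, e.g. (sketch):
#
#     ctx.save_for_backward(x, y)    # handled by call_method above
#     x, y = ctx.saved_tensors       # handled by var_getattr above
#     ctx.needs_input_grad           # constant tuple when known at trace time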


class LambdaVariable(VariableTracker):
    def __init__(self, fn, **kwargs):
        super().__init__(**kwargs)
        self.fn = fn

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return self.fn(*args, **kwargs)


class GetAttrVariable(VariableTracker):
    _nonvar_fields = {
        "name",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, obj, name, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name

    def __str__(self):
        return f"{self.__class__.__name__}({self.obj}, {self.name})"

    @staticmethod
    def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
        return getattr(base_proxy, attr)

    def as_proxy(self):
        return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)

    def const_getattr(self, tx, name):
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError
        step1 = tx.output.get_submodule(self.obj.module_key)
        if self.name not in step1.__dict__:
            raise NotImplementedError
        step2 = inspect.getattr_static(step1, self.name)
        if name not in step2.__dict__:
            raise NotImplementedError
        return inspect.getattr_static(step2, name)

    def reconstruct(self, codegen):
        codegen(self.obj)
        codegen.extend_output(codegen.create_load_attrs(self.name))

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return self.obj.call_method(tx, self.name, args, kwargs)

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if (
            name in ("__getitem__", "get")
            and self.name == "__dict__"
            and not kwargs
            and args[0].is_python_constant()
            and isinstance(
                self.obj,
                (variables.UserDefinedObjectVariable, variables.NNModuleVariable),
            )
        ):
            obj = self.obj
            key = args[0].as_python_constant()
            if obj.has_key_in_generic_dict(tx, key):
                # redirect to var_getattr on the original obj
                return obj.var_getattr(tx, key)

            # Return the default value for get
            if name == "get":
                if len(args) == 2:
                    return args[1]
                else:
                    return variables.ConstantVariable(None)

        elif (
            name == "__contains__"
            and self.name == "__dict__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
            and isinstance(
                self.obj,
                (variables.UserDefinedObjectVariable, variables.NNModuleVariable),
            )
        ):
            obj = self.obj
            key = args[0].as_python_constant()
            if obj.has_key_in_generic_dict(tx, key):
                return variables.ConstantVariable(True)
            else:
                return variables.ConstantVariable(False)

        return super().call_method(tx, name, args, kwargs)


class MethodWrapperVariable(VariableTracker):
    def __init__(self, method_wrapper, **kwargs):
        super().__init__(**kwargs)
        self.method_wrapper = method_wrapper

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
            args[0], variables.TensorVariable
        ):
            assert len(args) == 1 and len(kwargs) == 0

            return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)

        super().call_function(tx, args, kwargs)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.method_wrapper


class GetSetDescriptorVariable(VariableTracker):
    def __init__(self, desc, **kwargs):
        super().__init__(**kwargs)
        self.desc = desc

    def var_getattr(self, tx, name):
        if name == "__get__" and self.source:
            from .builder import VariableBuilder

            return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
                self.desc.__get__
            )
        else:
            return super().var_getattr(tx, name)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.desc


class PythonModuleVariable(VariableTracker):
    _nonvar_fields = {
        "value",
        "is_torch",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, value: types.ModuleType, **kwargs):
        super().__init__(**kwargs)
        self.value = value
        self.is_torch = self.value is torch or self.value.__name__.startswith("torch.")

    def python_type(self):
        return types.ModuleType

    def as_python_constant(self):
        return self.value

    def __repr__(self):
        return f"PythonModuleVariable({self.value})"

    def call_hasattr(self, tx, name):
        if self.is_torch:
            result = hasattr(self.value, name)
            return variables.ConstantVariable.create(result)
        return super().call_hasattr(tx, name)


class TypingVariable(VariableTracker):
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__" and len(args) == 1:
            return variables.ConstantVariable.create(
                self.value[args[0].as_python_constant()],
            )
        unimplemented("typing")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value


@functools.lru_cache(maxsize=1)
def get_np_to_tnp_map():
    from ..utils import NP_TO_TNP_MODULE

    np_fn_to_tnp_fn = {}

    for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
        for fn_name, tnp_fn in tnp_mod.__dict__.items():
            if callable(tnp_fn):
                # some internal details do leak from tnp
                # which are not part of numpy API.
                if np_fn := getattr(np_mod, fn_name, None):
                    np_fn_to_tnp_fn[np_fn] = tnp_fn

    return np_fn_to_tnp_fn
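

# For example, get_np_to_tnp_map().get(np.sum) would typically resolve to
# torch._numpy.sum (assuming numpy is installed); NumpyVariable.call_function
# below uses this mapping to translate traced numpy calls into torch._numpy ones.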


class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*`. Currently, it is able to trace a small subset
    of numpy functions as well as numpy dtypes.
    """

    constant_fold_functions = (tnp.issubdtype,)

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    @classmethod
    def can_constant_fold_through(cls, fn):
        mod = fn.__module__.split(".")
        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
        return fn in cls.constant_fold_functions

    @classmethod
    def get_constant_collection_for_func(cls, fn):
        mod = fn.__module__.split(".")
        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
        return np_constant_collections_map.get(fn, None)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if not config.trace_numpy:
            unimplemented(f"numpy.{self.value}()")

        from ..utils import numpy_to_tensor_wrapper
        from .tensor import NumpyNdarrayVariable

        func = get_np_to_tnp_map().get(self.value)
        if func is None:
            unimplemented(
                f"Can't find numpy function {self.value} in torch._numpy. "
                "Please file an issue to request support for this function."
            )

        # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo)
        if (
            collection_variable_typ := self.get_constant_collection_for_func(func)
        ) is not None:
            try:
                return collection_variable_typ(
                    self.value(
                        *[x.as_python_constant() for x in args],
                        **{k: v.as_python_constant() for k, v in kwargs.items()},
                    )
                )
            except NotImplementedError:
                unimplemented(
                    f"{self.value.__name__} with non-const args: {args} {kwargs}"
                )
        else:
            if (
                func.__module__ == "torch._numpy.random"
                and config.use_numpy_random_stream
            ):
                msg = f"delegate '{func.__qualname__}' to NumPy itself via "
                msg += f"config.use_numpy_random_stream={config.use_numpy_random_stream}"
                unimplemented(msg)

            args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs)

            if self.can_constant_fold_through(func) and (
                check_unspec_or_constant_args(args, kwargs)
            ):
                # constant fold
                return variables.ConstantVariable.create(
                    self.as_python_constant()(
                        *[x.as_python_constant() for x in args],
                        **{k: v.as_python_constant() for k, v in kwargs.items()},
                    ),
                )

            # TODO Add all the functions that go from constants to constants to can_constant_fold_through
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_to_tensor_wrapper(func),
                *proxy_args_kwargs(args, kwargs),
            )
            return NumpyNdarrayVariable.create(tx, proxy)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented("numpy")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def as_proxy(self):
        if config.trace_numpy and isinstance(self.value, type):
            # This handles numpy dtype attributes such as np.float32
            # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph
            # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does
            return self.value.__name__

        return super().as_proxy()
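

# Sketch of traced user code this covers when config.trace_numpy is enabled:
#
#     def fn(x):                    # x: torch.Tensor
#         a = x.numpy()
#         return np.add(a, 1)       # dispatched through torch._numpy.add
#
# Functions without a torch._numpy counterpart trigger unimplemented() above.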


# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
class NullVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __str__(self):
        return "NullVariable"

    def reconstruct(self, codegen):
        if sys.version_info < (3, 11):
            unimplemented("cannot reconstruct NullVariable in < Python 3.11")
        codegen.append_output(create_instruction("PUSH_NULL"))


class DeletedVariable(VariableTracker):
    """Marker used to implement delattr()"""


class StringFormatVariable(VariableTracker):
    """
    Represents a call to str.format(); we delay calling format until after the graph.
    """

    _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields}

    @classmethod
    def create(cls, format_string, sym_args, sym_kwargs):
        if all(
            x.is_python_constant()
            for x in itertools.chain(sym_args, sym_kwargs.values())
        ):
            return variables.ConstantVariable.create(
                format_string.format(
                    *[v.as_python_constant() for v in sym_args],
                    **{k: v.as_python_constant() for k, v in sym_kwargs.items()},
                )
            )
        return cls(format_string, list(sym_args), dict(sym_kwargs))

    def __init__(self, format_string, sym_args, sym_kwargs, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(format_string, str)
        self.format_string = format_string
        self.sym_args = sym_args
        self.sym_kwargs = sym_kwargs

    def __repr__(self):
        return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"

    def reconstruct(self, codegen):
        if sys.version_info >= (3, 11):
            codegen.append_output(create_instruction("PUSH_NULL"))
        codegen.append_output(codegen.create_load_const(self.format_string))
        codegen.append_output(codegen.create_load_attr("format"))
        codegen(variables.TupleVariable(self.sym_args))
        kwargs = {
            variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items()
        }
        codegen(variables.ConstDictVariable(kwargs))
        codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1))
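

# Sketch: a traced call like "loss={}".format(t) whose arguments are not all
# compile-time constants is kept as a StringFormatVariable; create() constant
# folds the fully-constant case, and reconstruct() emits the actual str.format
# call after the graph.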


class DebuggingVariable(VariableTracker):
    """
    Represents a call to a debugging function like print(), or something
    registered to config.reorderable_logging_functions.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    @staticmethod
    def is_reorderable_logging_function(obj):
        return (
            callable(obj)
            and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType))
            and obj in torch._dynamo.config.reorderable_logging_functions
        )

    def call_function(self, tx, args, kwargs):
        if tx.export:
            # For export cases, we can just make debugging functions no-ops
            return

        if not self.can_reorder_logs(self.value, args, kwargs):
            unimplemented(
                f"Reordering debugging function {self.value} "
                f"with inputs {args} {kwargs} is not yet implemented."
            )

        tx.debug_locals.append((self, list(args)))

    def reconstruct(self, codegen):
        return self.source.reconstruct(codegen)

    @staticmethod
    def can_reorder_logs(fn, args, kwargs) -> bool:
        """
        Run some additional checks for what sort of function calls we can
        actually reorder.
        """
        allowed_input_types = (
            variables.TensorVariable,
            variables.ConstantVariable,
            StringFormatVariable,
        )

        flat_args = pytree.tree_leaves([args, kwargs])
        for arg in flat_args:
            if not isinstance(arg, allowed_input_types):
                return False
        return True
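

# Reorderable logging functions are registered on the Dynamo config (a sketch;
# reorderable_logging_functions is assumed to be a set of callables):
#
#     import torch._dynamo.config as dynamo_config
#     dynamo_config.reorderable_logging_functions.add(print)
#
# Calls to such functions are collected into tx.debug_locals and replayed after
# the graph instead of forcing a graph break.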


class LoggingLoggerVariable(VariableTracker):
    """
    Represents a call to any of the logging.Logger methods
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if tx.export:
            # For export cases, we can just make debugging functions no-ops
            return
        unimplemented("Logger not supported for non-export cases")


class StopIterationVariable(VariableTracker):
    def __init__(self, args, **kwargs):
        super().__init__(**kwargs)
        self.args = args

    def reconstruct(self, codegen):
        codegen.load_import_from("builtins", "StopIteration")
        codegen.foreach(self.args)
        codegen.call_function(len(self.args), True)


class ConstantLikeVariable(VariableTracker):
    """self.value is a compile-time constant, but not a literal"""

    _error_prefix = "ConstantLikeVariable"
    try:
        from numpy import (
            dtype as np_dtype,
            floating as np_floating,
            generic as np_generic,
        )
    except ImportError:
        np_floating = type("invalid_type", (), {})
        np_dtype = type("invalid_type", (), {})
        np_generic = type("invalid_type", (), {})  # keep var_getattr safe without numpy

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        try:
            # we only support constant propagation for methods
            cargs = [x.as_python_constant() for x in args]
            ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
        except NotImplementedError:
            unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})")

        result = getattr(self.value, name)(*cargs, **ckwargs)

        if variables.ConstantVariable.is_literal(result):
            return variables.ConstantVariable.create(result)
        if isinstance(result, re.Match):
            return ConstantRegexMatchVariable(result)

        unimplemented(f"{self._error_prefix}.{name}() -> {result}")

    def var_getattr(self, tx, name: str) -> VariableTracker:
        result = getattr(self.value, name)
        if isinstance(result, self.np_floating):
            result = float(result)
        if isinstance(result, self.np_dtype):
            return NumpyDTypeVariable(result)
        if isinstance(result, type) and issubclass(result, self.np_generic):
            # things like x.dtype.type
            return NumpyVariable(result)
        if variables.ConstantVariable.is_literal(result):
            return variables.ConstantVariable.create(result)
        return GetAttrVariable(self, name)
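

# Examples of constant-like values handled by the subclasses below (a sketch):
#
#     pattern = re.compile(r"\d+")     # RegexPatternVariable
#     m = pattern.match("123abc")      # ConstantRegexMatchVariable via call_method
#     torch.__version__                # TorchVersionVariable
#
# Method calls on these objects are constant propagated when every argument is
# a Python constant; anything else graph breaks.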


class RegexPatternVariable(ConstantLikeVariable):
    _error_prefix = "re.Pattern"


class ConstantRegexMatchVariable(ConstantLikeVariable):
    _error_prefix = "re.Match"


class TorchVersionVariable(ConstantLikeVariable):
    _error_prefix = "torch.__version__"

    def __init__(self, **kwargs):
        kwargs.setdefault("value", torch.__version__)
        assert kwargs["value"] is torch.__version__
        super().__init__(**kwargs)


class NumpyTypeInfoVariable(ConstantLikeVariable):
    _error_prefix = "np.iinfo/np.finfo"


class NumpyDTypeVariable(ConstantLikeVariable):
    _error_prefix = "np.dtype[...]"

    def as_proxy(self):
        """Similar to how numpy dtype descriptors (e.g. np.float32) are handled by NumpyVariable:
        np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype.
        This also handles unsupported things nicely (i.e. structured arrays and object arrays).
        """
        return self.value.type.__name__


np_constant_collections_map = {
    tnp.finfo: NumpyTypeInfoVariable,
    tnp.iinfo: NumpyTypeInfoVariable,
    tnp.dtype: NumpyDTypeVariable,
}