builtin.py

  1. # mypy: ignore-errors
  2. import contextlib
  3. import functools
  4. import inspect
  5. import itertools
  6. import logging
  7. import math
  8. import operator
  9. import types
  10. from collections import defaultdict, OrderedDict
  11. from typing import Dict, List
  12. import torch
  13. from torch import sym_float, sym_int
  14. from .. import config, polyfill, variables
  15. from ..exc import (
  16. AttributeMutationError,
  17. unimplemented,
  18. Unsupported,
  19. UserError,
  20. UserErrorType,
  21. )
  22. from ..guards import GuardBuilder, install_guard
  23. from ..replay_record import DummyModule
  24. from ..source import AttrSource, GetItemSource, is_constant_source, TypeSource
  25. from ..utils import (
  26. check_constant_args,
  27. check_numpy_ndarray_args,
  28. check_unspec_or_constant_args,
  29. check_unspec_python_args,
  30. extract_fake_example_value,
  31. get_fake_value,
  32. guard_if_dyn,
  33. istype,
  34. numpy_operator_wrapper,
  35. proxy_args_kwargs,
  36. tensortype_to_dtype,
  37. )
  38. from .base import MutableLocal, VariableTracker
  39. from .constant import ConstantVariable
  40. from .ctx_manager import EventVariable, StreamVariable
  41. from .dicts import (
  42. ConstDictVariable,
  43. DefaultDictVariable,
  44. DictView,
  45. is_hashable,
  46. SetVariable,
  47. )
  48. from .lists import (
  49. BaseListVariable,
  50. ListIteratorVariable,
  51. ListVariable,
  52. SizeVariable,
  53. TupleIteratorVariable,
  54. TupleVariable,
  55. )
  56. from .tensor import (
  57. FakeItemVariable,
  58. supported_comparison_ops,
  59. SymNodeVariable,
  60. TensorVariable,
  61. UnspecializedPythonVariable,
  62. )
  63. from .user_defined import UserDefinedObjectVariable, UserDefinedVariable
  64. log = logging.getLogger(__name__)
  65. IN_PLACE_DESUGARING_MAP = {
  66. operator.iadd: operator.add,
  67. operator.isub: operator.sub,
  68. operator.imul: operator.mul,
  69. operator.ifloordiv: operator.floordiv,
  70. operator.itruediv: operator.truediv,
  71. operator.imod: operator.mod,
72. operator.imatmul: operator.matmul,
  73. operator.ilshift: operator.lshift,
  74. operator.irshift: operator.rshift,
  75. operator.ipow: operator.pow,
  76. operator.iand: operator.and_,
  77. operator.ior: operator.or_,
  78. operator.ixor: operator.xor,
  79. }
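# Illustrative sketch of how this map is used (see _handle_insert_op_in_graph below):
# when an in-place operator is applied to an immutable value, Python rebinds the name
# instead of mutating in place, so Dynamo desugars eagerly, roughly:
#
#   x = 1        # ConstantVariable
#   x += t       # operator.iadd on an int cannot mutate it, so it is traced as
#                # x = operator.add(x, t), using the out-of-place op from this map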
  80. def _polyfill_call_impl(name):
  81. """Create a BuiltinVariable.call_{name} method that inlines through polyfill.{name}"""
  82. def call_fn(self, tx, *args, **kwargs):
  83. return tx.inline_user_function_return(
  84. variables.UserFunctionVariable(fn), args, kwargs
  85. )
  86. fn = getattr(polyfill, name)
  87. call_fn.__name__ = f"call_{name}"
  88. return call_fn
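# Illustrative (assumed) usage of the factory above: a class-level assignment such as
#
#   call_all = _polyfill_call_impl("all")
#
# would give BuiltinVariable a call_all handler that inlines polyfill.all(...) through
# a UserFunctionVariable instead of graph-breaking.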
  89. class BuiltinVariable(VariableTracker):
  90. _SENTINEL = object()
  91. _nonvar_fields = {
  92. "fn",
  93. *VariableTracker._nonvar_fields,
  94. }
  95. @classmethod
  96. def create_with_source(cls, value, source):
  97. install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH))
  98. return BuiltinVariable(value, source=source)
  99. @staticmethod
  100. @functools.lru_cache(None)
  101. def _constant_fold_functions():
  102. fns = {
  103. abs,
  104. all,
  105. any,
  106. bool,
  107. callable,
  108. chr,
  109. divmod,
  110. float,
  111. getattr,
  112. int,
  113. len,
  114. max,
  115. min,
  116. ord,
  117. pow,
  118. repr,
  119. round,
  120. str,
  121. str.format,
  122. sum,
  123. type,
  124. operator.abs,
  125. operator.pos,
  126. operator.neg,
  127. operator.not_,
  128. operator.truth,
  129. operator.invert,
  130. operator.pow,
  131. operator.mul,
  132. operator.matmul,
  133. operator.floordiv,
  134. operator.truediv,
  135. operator.mod,
  136. operator.add,
  137. operator.sub,
  138. operator.getitem,
  139. operator.length_hint,
  140. operator.lshift,
  141. operator.rshift,
  142. operator.and_,
  143. operator.or_,
  144. operator.xor,
  145. operator.ipow,
  146. operator.imul,
  147. operator.imatmul,
  148. operator.ifloordiv,
  149. operator.itruediv,
  150. operator.imod,
  151. operator.iadd,
  152. operator.isub,
  153. operator.ilshift,
  154. operator.irshift,
  155. operator.iand,
  156. operator.ixor,
  157. operator.ior,
  158. operator.index,
  159. }
  160. from .tensor import supported_comparison_ops
  161. fns.update(supported_comparison_ops.values())
  162. fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
  163. return fns
  164. def can_constant_fold_through(self):
  165. return self.fn in self._constant_fold_functions()
  166. @staticmethod
  167. @functools.lru_cache(None)
  168. def _fx_graph_functions():
  169. fns = {
  170. operator.abs,
  171. operator.pos,
  172. operator.neg,
  173. operator.not_,
  174. operator.invert,
  175. operator.pow,
  176. operator.mul,
  177. operator.matmul,
  178. operator.floordiv,
  179. operator.truediv,
  180. operator.mod,
  181. operator.add,
  182. operator.lt,
  183. operator.gt,
  184. operator.ge,
  185. operator.le,
  186. operator.ne,
  187. operator.eq,
  188. operator.sub,
  189. operator.getitem,
  190. operator.length_hint,
  191. operator.lshift,
  192. operator.rshift,
  193. operator.and_,
  194. operator.or_,
  195. operator.xor,
  196. operator.ipow,
  197. operator.imul,
  198. operator.imatmul,
  199. operator.ifloordiv,
  200. operator.itruediv,
  201. operator.imod,
  202. operator.iadd,
  203. operator.isub,
  204. operator.ilshift,
  205. operator.irshift,
  206. operator.iand,
  207. operator.ixor,
  208. operator.ior,
  209. }
  210. return fns
  211. @staticmethod
  212. @functools.lru_cache(None)
  213. def _binops():
  214. # function -> ([forward name, reverse name, in-place name], in-place op)
  215. fns = {
  216. operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd),
  217. operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub),
  218. operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul),
  219. operator.truediv: (
  220. ["__truediv__", "__rtruediv__", "__itruediv__"],
  221. operator.itruediv,
  222. ),
  223. operator.floordiv: (
  224. ["__floordiv__", "__rfloordiv__", "__ifloordiv__"],
  225. operator.ifloordiv,
  226. ),
  227. operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod),
  228. pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
  229. operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow),
  230. operator.lshift: (
  231. ["__lshift__", "__rlshift__", "__ilshift__"],
  232. operator.ilshift,
  233. ),
  234. operator.rshift: (
  235. ["__rshift__", "__rrshift__", "__irshift__"],
  236. operator.irshift,
  237. ),
238. # NB: The following binary operators are not supported for now, since the
  239. # corresponding magic methods aren't defined on SymInt / SymFloat:
  240. # operator.matmul
  241. # divmod
  242. # operator.and_
  243. # operator.or_
  244. # operator.xor
  245. }
  246. return fns
  247. @staticmethod
  248. @functools.lru_cache(None)
  249. def _binop_handlers():
  250. # Multiple dispatch mechanism defining custom binop behavior for certain type
  251. # combinations. Handlers are attempted in order, and will be used if the type checks
  252. # match. They are expected to have the signature:
  253. # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
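# Illustrative sketch of a table entry (the real entries are built below): a handler is
# registered as a ((type_a, type_b), fn) pair, e.g. roughly
#
#   op_handlers[operator.add].append(((TupleVariable, TupleVariable), tuple_add_handler))
#
# and _find_binop_handler later matches argument types against these pairs via issubclass.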
  254. from .dicts import DictKeys, SetVariable
  255. from .functions import BaseUserFunctionVariable, UserFunctionVariable
  256. from .nn_module import NNModuleVariable
  257. from .tensor import supported_const_comparison_ops
  258. from .torch import BaseTorchVariable
  259. from .user_defined import (
  260. UserDefinedClassVariable,
  261. UserDefinedObjectVariable,
  262. UserDefinedVariable,
  263. )
  264. # Override table contains: op_fn -> [list of handlers]
  265. op_handlers = {}
  266. for (
  267. op,
  268. (magic_method_names, in_place_op),
  269. ) in BuiltinVariable._binops().items():
  270. op_handlers[op] = []
  271. op_handlers[in_place_op] = []
  272. forward_name, reverse_name, inplace_name = magic_method_names
  273. # User-defined args (highest precedence)
  274. def user_defined_handler(
  275. tx,
  276. a,
  277. b,
  278. *,
  279. forward_name=forward_name,
  280. reverse_name=reverse_name,
  281. ):
  282. # Manually handle reversing logic if needed (e.g. call __radd__)
  283. # TODO: If we expand this to handle tensor args, we need to manually
  284. # handle cases like this:
  285. #
  286. # class A(int):
  287. # def __radd__(self, other):
  288. # print("woof")
  289. # torch.randn(3) + A(3)
  290. #
  291. # In this example, A.__radd__() is not called -> nothing is printed, because
  292. # Tensor.__add__ only does a subtype test against int, ignoring the subclass.
  293. # To be fully correct, we should not call A.__radd__() here, and there may be
  294. # other cases to reason about and add exceptions for.
  295. if isinstance(a, UserDefinedVariable):
  296. return a.call_method(tx, forward_name, [b], {})
  297. else:
  298. return b.call_method(tx, reverse_name, [a], {})
  299. op_handlers[op].append(
  300. ((UserDefinedVariable, VariableTracker), user_defined_handler)
  301. )
  302. op_handlers[op].append(
  303. ((VariableTracker, UserDefinedVariable), user_defined_handler)
  304. )
  305. def user_defined_inplace_handler(tx, a, b, *, forward_name=inplace_name):
  306. return a.call_method(tx, forward_name, [b], {})
  307. op_handlers[in_place_op].append(
  308. ((UserDefinedVariable, VariableTracker), user_defined_inplace_handler)
  309. )
  310. op_handlers[in_place_op].append(
  311. ((VariableTracker, UserDefinedVariable), user_defined_inplace_handler)
  312. )
  313. # Dynamic shape args
  314. def dynamic_handler(tx, a, b, *, fn=op):
  315. from .builder import wrap_fx_proxy
  316. return wrap_fx_proxy(
  317. tx,
  318. tx.output.create_proxy(
  319. "call_function", fn, *proxy_args_kwargs([a, b], {})
  320. ),
  321. )
  322. op_handlers[op].append(
  323. ((SymNodeVariable, VariableTracker), dynamic_handler)
  324. )
  325. op_handlers[op].append(
  326. ((VariableTracker, SymNodeVariable), dynamic_handler)
  327. )
328. # NB: Prefer the out-of-place op when calling an in-place op, to generate a valid graph
  329. op_handlers[in_place_op].append(
  330. ((SymNodeVariable, VariableTracker), dynamic_handler)
  331. )
  332. op_handlers[in_place_op].append(
  333. ((VariableTracker, SymNodeVariable), dynamic_handler)
  334. )
  335. # Special cases - lower precedence but still prefer these over constant folding
  336. # List-like addition (e.g. [1, 2] + [3, 4])
  337. def tuple_add_handler(tx, a, b):
  338. return TupleVariable([*a.items, *b.unpack_var_sequence(tx)])
  339. def size_add_handler(tx, a, b):
  340. return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])
  341. list_like_addition_handlers = [
  342. # NB: Prefer the tuple-specific logic over base logic because of
  343. # some SizeVariable weirdness. Specifically, the tuple-specific logic
  344. # drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
  345. (
  346. (SizeVariable, SizeVariable),
  347. size_add_handler,
  348. ),
  349. (
  350. (TupleVariable, TupleVariable),
  351. tuple_add_handler,
  352. ),
  353. (
  354. (TupleVariable, ConstantVariable),
  355. tuple_add_handler,
  356. ),
  357. (
  358. (ConstantVariable, TupleVariable),
  359. lambda tx, a, b: TupleVariable(
  360. [*a.unpack_var_sequence(tx), *b.items],
  361. ),
  362. ),
  363. (
  364. (
  365. ListVariable,
  366. (BaseListVariable, ConstantVariable, ListIteratorVariable),
  367. ),
  368. lambda tx, a, b: ListVariable(
  369. [*a.items, *b.unpack_var_sequence(tx)], mutable_local=MutableLocal()
  370. ),
  371. ),
  372. (
  373. (BaseListVariable, BaseListVariable),
  374. lambda tx, a, b: type(a)([*a.items, *b.items]),
  375. ),
  376. ]
  377. op_handlers[operator.add].extend(list_like_addition_handlers)
  378. def list_iadd_handler(tx, a, b):
  379. if not a.mutable_local or not b.has_unpack_var_sequence(tx):
  380. # Handler doesn't apply
  381. return None
  382. seq = b.unpack_var_sequence(tx)
  383. tx.output.side_effects.mutation(a)
  384. a.items.extend(seq)
  385. return a
  386. list_like_iadd_handlers = [
  387. (
  388. (ListVariable, VariableTracker),
  389. list_iadd_handler,
  390. ),
  391. (
  392. (TupleVariable, TupleVariable),
  393. tuple_add_handler,
  394. ),
  395. (
  396. (TupleVariable, ConstantVariable),
  397. tuple_add_handler,
  398. ),
  399. ]
  400. op_handlers[operator.iadd].extend(list_like_iadd_handlers)
  401. # List-like expansion (e.g. [1, 2, 3] * 3)
  402. def expand_list_like(tx, lst, const):
  403. if isinstance(lst, ConstantVariable):
  404. lst, const = const, lst
  405. return lst.__class__(
  406. items=lst.items * const.as_python_constant(),
  407. mutable_local=MutableLocal(),
  408. )
  409. list_like_expansion_handlers = [
  410. ((ListVariable, ConstantVariable), expand_list_like),
  411. ((TupleVariable, ConstantVariable), expand_list_like),
  412. ((ConstantVariable, ListVariable), expand_list_like),
  413. ((ConstantVariable, TupleVariable), expand_list_like),
  414. ]
  415. op_handlers[operator.mul].extend(list_like_expansion_handlers)
  416. size_or_tuple = (SizeVariable, TupleVariable)
  417. has_set_items = (SetVariable, DictKeys)
  418. def create_cmp_op_handlers(op):
  419. def compare_by_value(tx, a, b):
  420. return ConstantVariable(op(a.value, b.value))
  421. result = [((ConstantVariable, ConstantVariable), compare_by_value)]
  422. if op in supported_const_comparison_ops.values():
  423. # Tensor is None, List is not None, etc
  424. none_result = op(object(), None)
  425. if op.__name__.startswith("is_"):
  426. def never(tx, a, b):
  427. return ConstantVariable(none_result)
  428. obj_op_none = never
  429. none_op_obj = never
  430. else:
  431. def obj_op_none(tx, a, b: ConstantVariable):
  432. if b.value is None or b.value is True or b.value is False:
  433. return ConstantVariable(none_result)
  434. def none_op_obj(tx, a: ConstantVariable, b):
  435. if a.value is None or a.value is True or a.value is False:
  436. return ConstantVariable(none_result)
  437. types_that_are_never_none = (
  438. TensorVariable,
  439. SymNodeVariable,
  440. NNModuleVariable,
  441. BaseListVariable,
  442. UserDefinedVariable,
  443. BaseUserFunctionVariable,
  444. ConstDictVariable,
  445. BaseTorchVariable,
  446. )
  447. result.extend(
  448. [
  449. (
  450. (types_that_are_never_none, ConstantVariable),
  451. obj_op_none,
  452. ),
  453. (
  454. (ConstantVariable, types_that_are_never_none),
  455. none_op_obj,
  456. ),
  457. ]
  458. )
  459. def list_compare_nocheck(tx, left, right):
  460. return BaseListVariable.list_compare(tx, op, left, right)
  461. def list_compare_check(tx, left, right):
  462. if type(left) is not type(
  463. right
  464. ): # Mismatch in BaseListVariable subclasses
  465. unimplemented(f"{op.__name__}({left}, {right})")
  466. return BaseListVariable.list_compare(tx, op, left, right)
  467. def compare_set_items(tx, left, right):
  468. return ConstantVariable(op(left.set_items, right.set_items))
  469. def compare_via_method(tx, left, right):
  470. return left.call_method(tx, f"__{op.__name__}__", [right], {})
  471. if op.__name__.startswith("is_"):
  472. compare_user_defined = compare_by_value
  473. else:
  474. compare_user_defined = compare_via_method
  475. op_var = BuiltinVariable(op)
  476. result.extend(
  477. [
  478. (
  479. (
  480. (UserFunctionVariable, BuiltinVariable),
  481. (UserFunctionVariable, BuiltinVariable),
  482. ),
  483. lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
  484. ),
  485. (
  486. (
  487. NNModuleVariable,
  488. NNModuleVariable,
  489. ),
  490. lambda tx, a, b: ConstantVariable(
  491. op(
  492. tx.output.get_submodule(a.module_key),
  493. tx.output.get_submodule(b.module_key),
  494. )
  495. ),
  496. ),
  497. ((size_or_tuple, size_or_tuple), list_compare_nocheck),
  498. (
  499. (variables.BaseListVariable, variables.BaseListVariable),
  500. list_compare_check,
  501. ),
  502. ((has_set_items, has_set_items), compare_set_items),
  503. (
  504. (UserDefinedObjectVariable, UserDefinedObjectVariable),
  505. compare_user_defined,
  506. ),
  507. (
  508. (UserDefinedClassVariable, UserDefinedClassVariable),
  509. compare_user_defined,
  510. ),
  511. (
  512. (
  513. (StreamVariable, EventVariable, ConstantVariable),
  514. (StreamVariable, EventVariable, ConstantVariable),
  515. ),
  516. compare_by_value,
  517. ),
  518. (
  519. (TensorVariable, VariableTracker),
  520. op_var._comparison_with_tensor,
  521. ),
  522. (
  523. (VariableTracker, TensorVariable),
  524. op_var._comparison_with_tensor,
  525. ),
  526. (
  527. (SymNodeVariable, VariableTracker),
  528. op_var._comparison_with_symnode,
  529. ),
  530. (
  531. (VariableTracker, SymNodeVariable),
  532. op_var._comparison_with_symnode,
  533. ),
  534. ]
  535. )
  536. if op.__name__.startswith("is_"):
  537. def handle_is(tx, left, right):
538. # If the two objects are of different types, we can safely return False
539. # and True for `is` and `is not`, respectively
  540. if type(left) is not type(right):
  541. return ConstantVariable.create(op.__name__ != "is_")
  542. result.append(((VariableTracker, VariableTracker), handle_is))
  543. return result
  544. for op in supported_comparison_ops.values():
  545. assert callable(op)
  546. assert op not in op_handlers
  547. op_handlers[op] = create_cmp_op_handlers(op)
  548. return op_handlers
  549. @staticmethod
  550. def _find_binop_handler(op, a_type, b_type):
  551. handlers = BuiltinVariable._binop_handlers().get(op)
  552. if handlers is None:
  553. return None
  554. matches = []
  555. for (type1, type2), handler in handlers:
  556. if issubclass(a_type, type1) and issubclass(b_type, type2):
  557. matches.append(handler)
  558. return matches
  559. def can_insert_in_graph(self):
  560. return self.fn in self._fx_graph_functions()
  561. def __init__(self, fn, **kwargs):
  562. super().__init__(**kwargs)
  563. self.fn = fn
  564. def __str__(self):
  565. if self.fn is None:
  566. name = "None"
  567. else:
  568. name = self.fn.__name__
  569. return f"{self.__class__.__name__}({name})"
  570. def python_type(self):
  571. return type(self.fn)
  572. def as_python_constant(self):
  573. return self.fn
  574. def as_proxy(self):
  575. DTYPE = {
  576. bool: torch.bool,
  577. int: torch.int64,
  578. float: torch.float64,
  579. }
  580. if self.fn in DTYPE:
  581. return DTYPE[self.fn]
  582. return super().as_proxy()
  583. def reconstruct(self, codegen):
  584. name = self.fn.__name__
  585. assert self.fn.__module__ == "builtins"
  586. assert name not in codegen.tx.f_globals, "shadowed global"
  587. codegen.append_output(codegen.create_load_global(name, False, add=True))
  588. def constant_args(self, *args, **kwargs):
  589. return check_constant_args(args, kwargs)
  590. def tensor_args(self, *args):
  591. any_tensor = False
  592. for arg in args:
  593. if isinstance(arg, variables.GetAttrVariable):
  594. return False
  595. any_tensor = any_tensor or isinstance(arg, variables.TensorVariable)
  596. return any_tensor
  597. def tensor_args_type(self, arg_types):
  598. any_tensor = False
  599. for arg_type in arg_types:
  600. if issubclass(arg_type, variables.GetAttrVariable):
  601. return False
  602. any_tensor = any_tensor or issubclass(arg_type, variables.TensorVariable)
  603. return any_tensor
  604. def python_and_tensor_constant_only(self, *args, **kwargs):
  605. tensor_args = []
  606. non_tensor_args = []
  607. for i in itertools.chain(args, kwargs.values()):
  608. if isinstance(i, variables.TensorVariable):
  609. tensor_args.append(i)
  610. else:
  611. non_tensor_args.append(i)
  612. return all(
  613. is_constant_source(t.source) if t.source is not None else False
  614. for t in tensor_args
  615. ) and self.constant_args(*non_tensor_args)
  616. @staticmethod
  617. def unwrap_unspec_args_kwargs(args, kwargs):
  618. return [x.as_python_constant() for x in args], {
  619. k: v.as_python_constant() for k, v in kwargs.items()
  620. }
  621. def has_constant_handler(self, args, kwargs):
  622. return self.can_constant_fold_through() and check_unspec_or_constant_args(
  623. args, kwargs
  624. )
  625. @staticmethod
  626. def _make_handler(fn, arg_types: List[type], has_kwargs: bool):
  627. from .builder import SourcelessBuilder
  628. from .lazy import LazyVariableTracker
  629. obj = BuiltinVariable(fn)
  630. handlers = []
  631. if any(issubclass(t, LazyVariableTracker) for t in arg_types):
  632. return lambda tx, args, kwargs: obj.call_function(
  633. tx, [v.realize() for v in args], kwargs
  634. )
  635. if inspect.isclass(fn) and issubclass(fn, Exception):
  636. def create_exception_class_object(tx, args, kwargs):
  637. if fn is AssertionError and not all(
  638. isinstance(x, variables.ConstantVariable)
  639. and isinstance(x.value, str)
  640. for x in args
  641. ):
  642. unimplemented("assert with non-string message")
  643. return variables.ExceptionVariable(fn, args, **kwargs)
  644. return create_exception_class_object
  645. if obj.can_insert_in_graph() and not (
  646. fn is operator.getitem
  647. and not issubclass(arg_types[0], variables.TensorVariable)
  648. ):
  649. if obj.tensor_args_type(arg_types):
  650. return obj._handle_insert_op_in_graph
  651. elif has_kwargs:
  652. # need runtime check for kwargs
  653. handlers.append(obj._handle_insert_op_in_graph)
  654. # Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.)
  655. # NB: Tensor args are handled above and not here
  656. if len(arg_types) == 2 and not has_kwargs:
  657. # Try to find a handler for the arg types; otherwise, fall through to constant handler
  658. binop_handlers = BuiltinVariable._find_binop_handler(fn, *arg_types)
  659. if not binop_handlers:
  660. pass
  661. elif len(binop_handlers) == 1:
  662. (binop_handler,) = binop_handlers
  663. handlers.append(lambda tx, args, _: binop_handler(tx, *args))
  664. else:
  665. def call_binop_handlers(tx, args, _):
  666. for fn in binop_handlers:
  667. rv = fn(tx, *args)
  668. if rv:
  669. return rv
  670. handlers.append(call_binop_handlers)
  671. self_handler = getattr(obj, f"call_{fn.__name__}", None)
  672. if self_handler:
  673. def call_self_handler(tx, args, kwargs):
  674. try:
  675. result = self_handler(tx, *args, **kwargs)
  676. if result is not None:
  677. return result
  678. except TypeError:
679. # Check whether the binding is bad. inspect.signature().bind() is expensive,
680. # so only check it when the handler call fails.
  681. try:
  682. inspect.signature(self_handler).bind(tx, *args, **kwargs)
  683. except TypeError as e:
  684. has_constant_handler = obj.has_constant_handler(args, kwargs)
  685. if not has_constant_handler:
  686. log.warning(
  687. "incorrect arg count %s %s and no constant handler",
  688. self_handler,
  689. e,
  690. )
  691. unimplemented(
  692. f"invalid handler args {self_handler} {args} {kwargs}"
  693. )
  694. else:
  695. raise
  696. except Unsupported as exc:
  697. has_constant_handler = obj.has_constant_handler(args, kwargs)
  698. if not has_constant_handler:
  699. raise
  700. # Actually, we will handle this just fine
  701. exc.remove_from_stats()
  702. handlers.append(call_self_handler)
  703. if obj.can_constant_fold_through():
  704. builder = SourcelessBuilder.create
  705. if (
  706. all(issubclass(x, ConstantVariable) for x in arg_types)
  707. and not has_kwargs
  708. ):
  709. def constant_fold_handler(tx, args, kwargs):
  710. # fast path
  711. try:
  712. res = fn(
  713. *[x.as_python_constant() for x in args],
  714. )
  715. except Exception as exc:
  716. unimplemented(f"constant fold exception: {repr(exc)}")
  717. return builder(tx, res)
  718. else:
  719. def constant_fold_handler(tx, args, kwargs):
  720. # path with a runtime check
  721. if check_unspec_or_constant_args(args, kwargs):
  722. try:
  723. res = fn(
  724. *[x.as_python_constant() for x in args],
  725. **{
  726. k: v.as_python_constant() for k, v in kwargs.items()
  727. },
  728. )
  729. except Exception as exc:
  730. unimplemented(f"constant fold exception: {repr(exc)}")
  731. return builder(tx, res)
  732. handlers.append(constant_fold_handler)
  733. error_msg = f"builtin: {fn.__name__} {arg_types} {has_kwargs}"
  734. if len(handlers) == 0:
  735. return lambda *args: unimplemented(error_msg)
  736. elif len(handlers) == 1:
  737. (handler,) = handlers
738. def builtin_dispatch(tx, args, kwargs):
  739. rv = handler(tx, args, kwargs)
  740. if rv:
  741. return rv
  742. unimplemented(error_msg)
  743. else:
744. def builtin_dispatch(tx, args, kwargs):
  745. for fn in handlers:
  746. rv = fn(tx, args, kwargs)
  747. if rv:
  748. return rv
  749. unimplemented(error_msg)
750. return builtin_dispatch
  751. def _handle_insert_op_in_graph(self, tx, args, kwargs):
  752. from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
  753. if kwargs and not self.tensor_args(*args, *kwargs.values()):
  754. return
  755. fn = self.fn
  756. try:
  757. # Constant fold for constant tensor and python constants
  758. if self.python_and_tensor_constant_only(*args, **kwargs):
  759. from ..bytecode_transformation import unique_id
  760. from .functions import invoke_and_store_as_constant
  761. return invoke_and_store_as_constant(
  762. tx, fn, unique_id(fn.__name__), args, kwargs
  763. )
  764. if fn in IN_PLACE_DESUGARING_MAP and isinstance(
  765. args[0], variables.ConstantVariable
  766. ):
767. # In-place operators like += usually mutate tensor
  768. # values, but in the edge case of immutable values they
  769. # re-bind the variable.
  770. #
  771. # The easiest way to keep the graph consistent in this
  772. # scenario is to de-sugar eagerly.
  773. fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]]
  774. if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
  775. # Standard indexing will force specialization due to
  776. # __index__. Rewrite as a regular torch op which will
  777. # trace fine
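# Illustrative sketch: t[i] with i a SymNodeVariable is traced as
# torch.select(t, 0, i), which avoids the __index__ specialization.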
  778. fn, args = torch.select, [
  779. args[0],
  780. variables.ConstantVariable.create(0),
  781. args[1],
  782. ]
  783. # Interaction between ndarray and tensors:
  784. # We prefer the tensor op whenever there are tensors involved
  785. if check_numpy_ndarray_args(args, kwargs) and not any(
  786. type(arg) == variables.TensorVariable for arg in args
  787. ):
  788. proxy = tx.output.create_proxy(
  789. "call_function",
  790. numpy_operator_wrapper(fn),
  791. *proxy_args_kwargs(args, kwargs),
  792. )
  793. return wrap_fx_proxy_cls(variables.NumpyNdarrayVariable, tx, proxy)
  794. proxy = tx.output.create_proxy(
  795. "call_function",
  796. fn,
  797. *proxy_args_kwargs(args, kwargs),
  798. )
  799. if any(isinstance(arg, FakeItemVariable) for arg in args):
  800. return wrap_fx_proxy_cls(
  801. FakeItemVariable,
  802. tx,
  803. proxy,
  804. )
  805. elif check_unspec_python_args(args, kwargs):
  806. _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
  807. raw_value = fn(*_args, **_kwargs)
  808. need_unwrap = any(
  809. x.need_unwrap
  810. for x in itertools.chain(args, kwargs.values())
  811. if isinstance(x, variables.UnspecializedPythonVariable)
  812. )
  813. return wrap_fx_proxy_cls(
  814. UnspecializedPythonVariable,
  815. tx,
  816. proxy,
  817. raw_value=raw_value,
  818. need_unwrap=need_unwrap,
  819. )
  820. elif all(isinstance(x, SymNodeVariable) for x in args):
  821. return SymNodeVariable.create(tx, proxy, None)
  822. else:
823. # Workaround for vision_maskrcnn due to a precision difference:
824. # specialize the dividend when a float is divided by a tensor
  825. if fn is operator.truediv and isinstance(
  826. args[0], variables.UnspecializedPythonVariable
  827. ):
  828. args[0] = args[0].convert_to_constant(tx)
  829. return wrap_fx_proxy(tx, proxy)
  830. except NotImplementedError:
  831. unimplemented(f"partial tensor op: {self} {args} {kwargs}")
  832. call_function_handler_cache = {}
  833. def call_function(
  834. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  835. ) -> "VariableTracker":
  836. if kwargs:
  837. kwargs = {k: v.realize() for k, v in kwargs.items()}
  838. key = (self.fn, *(type(x) for x in args), True)
  839. else:
  840. key = (self.fn, *(type(x) for x in args))
  841. handler = self.call_function_handler_cache.get(key)
  842. if not handler:
  843. self.call_function_handler_cache[key] = handler = self._make_handler(
  844. self.fn, [type(x) for x in args], bool(kwargs)
  845. )
  846. return handler(tx, args, kwargs)
  847. def call_method(
  848. self,
  849. tx,
  850. name,
  851. args: "List[VariableTracker]",
  852. kwargs: "Dict[str, VariableTracker]",
  853. ) -> "VariableTracker":
  854. if self.fn == object and name == "__setattr__":
  855. assert len(args) == 3
  856. assert len(kwargs) == 0
  857. obj, name_var, val = args
  858. obj = obj.realize()
  859. if (
  860. isinstance(obj, UserDefinedObjectVariable)
  861. and tx.output.side_effects.is_attribute_mutation(obj)
  862. and name_var.is_python_constant()
  863. ):
  864. return obj.method_setattr_standard(tx, name_var, val)
  865. if self.fn == dict and name == "fromkeys":
  866. return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
  867. if self.fn == itertools.chain and name == "from_iterable":
  868. assert len(args) == 1
  869. assert len(kwargs) == 0
  870. obj = args[0]
  871. items = []
  872. for item in obj.unpack_var_sequence(tx):
  873. items.extend(item.unpack_var_sequence(tx))
  874. return variables.TupleVariable(items)
  875. return super().call_method(tx, name, args, kwargs)
  876. def _call_int_float(self, tx, arg):
  877. # Handle cases like int(torch.seed())
  878. # Also handle sym_float to sym_int cases
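# Illustrative sketch: int(t) for a single-element TensorVariable is traced as
# torch.sym_int(t.item()); float(s) for a SymNodeVariable becomes torch.sym_float(s).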
  879. if isinstance(arg, (SymNodeVariable, variables.TensorVariable)):
  880. if isinstance(arg, variables.TensorVariable):
  881. item = arg.call_method(tx, "item", [], {})
  882. else:
  883. item = arg
  884. fn_ = sym_int if self.fn is int else sym_float
  885. from torch._dynamo.variables.builder import wrap_fx_proxy
  886. return wrap_fx_proxy(
  887. tx=tx,
  888. proxy=tx.output.create_proxy(
  889. "call_function",
  890. fn_,
  891. (item.as_proxy(),),
  892. {},
  893. ),
  894. )
  895. call_int = _call_int_float
  896. call_float = _call_int_float
  897. def call_str(self, tx, arg):
  898. # Handle `str` on a user defined function
  899. if isinstance(arg, (variables.UserFunctionVariable)):
  900. return variables.ConstantVariable.create(value=str(arg.fn))
  901. def _call_min_max(self, tx, *args):
  902. if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
  903. # expand iterable
  904. items = args[0].unpack_var_sequence(tx)
  905. return self._call_min_max_seq(tx, items)
  906. elif len(args) == 2:
  907. return self._call_min_max_binary(tx, args[0], args[1])
  908. elif len(args) > 2:
  909. return self._call_min_max_seq(tx, args)
  910. def _call_min_max_seq(self, tx, items):
  911. assert len(items) > 0
  912. if len(items) == 1:
  913. return items[0]
  914. return functools.reduce(functools.partial(self._call_min_max_binary, tx), items)
  915. def _call_min_max_binary(self, tx, a, b):
  916. if self.tensor_args(a, b):
  917. if not isinstance(a, variables.TensorVariable):
  918. a, b = b, a
  919. assert isinstance(a, variables.TensorVariable)
920. # the result of an item() call is a scalar; convert it to a tensor
  921. if isinstance(a, FakeItemVariable):
  922. a = variables.TorchInGraphFunctionVariable(torch.tensor).call_function(
  923. tx, [a], {}
  924. )
925. # Dynamic input does not get resolved; rather, it gets stored as a call_function node
  926. if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
  927. from .builder import wrap_fx_proxy_cls
  928. return wrap_fx_proxy_cls(
  929. type(a),
  930. tx=tx,
  931. proxy=tx.output.create_proxy(
  932. "call_function",
  933. self.fn,
  934. *proxy_args_kwargs([a, b], {}),
  935. ),
  936. )
  937. # convert min/max to torch ops
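# Illustrative sketch of the mapping below: min(t, 5.0) with a constant bound becomes
# torch.clamp(t, max=5.0) (and max(t, 5.0) becomes torch.clamp(t, min=5.0)), while
# min(t, u) on two traced values becomes torch.minimum(t, u).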
  938. if b.is_python_constant():
  939. if isinstance(a, variables.NumpyNdarrayVariable):
  940. import numpy as np
  941. fn = variables.NumpyVariable(np.clip)
  942. else:
  943. fn = variables.TorchInGraphFunctionVariable(torch.clamp)
  944. kwargs = {"min": b} if (self.fn is max) else {"max": b}
  945. result = fn.call_function(tx, [a], kwargs)
  946. else:
  947. if isinstance(a, variables.NumpyNdarrayVariable):
  948. import numpy as np
  949. fn = {max: np.maximum, min: np.minimum}[self.fn]
  950. fn = variables.NumpyVariable(fn)
  951. else:
  952. fn = {max: torch.maximum, min: torch.minimum}[self.fn]
  953. fn = variables.TorchInGraphFunctionVariable(fn)
  954. result = fn.call_function(tx, [a, b], {})
  955. # return unspec if both a, b are unspec or const
  956. if all(
  957. isinstance(
  958. i,
  959. (
  960. variables.UnspecializedPythonVariable,
  961. variables.ConstantVariable,
  962. ),
  963. )
  964. for i in [a, b]
  965. ):
  966. if any(isinstance(val, FakeItemVariable) for val in [a, b]):
  967. return variables.FakeItemVariable.from_tensor_variable(result)
  968. if b.is_python_constant():
  969. raw_b = b.as_python_constant()
  970. else:
  971. raw_b = b.raw_value
  972. if self.fn is max:
  973. raw_res = max(a.raw_value, raw_b)
  974. else:
  975. raw_res = min(a.raw_value, raw_b)
  976. need_unwrap = any(
  977. x.need_unwrap
  978. for x in [a, b]
  979. if isinstance(x, variables.UnspecializedPythonVariable)
  980. )
  981. return variables.UnspecializedPythonVariable.from_tensor_variable(
  982. result, raw_res, need_unwrap
  983. )
  984. # otherwise return tensor
  985. else:
  986. return result
  987. elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
  988. fn = torch.sym_max if self.fn is max else torch.sym_min
  989. proxy = tx.output.create_proxy(
  990. "call_function", fn, *proxy_args_kwargs([a, b], {})
  991. )
  992. return SymNodeVariable.create(tx, proxy, None)
  993. call_min = _call_min_max
  994. call_max = _call_min_max
  995. def call_abs(self, tx, arg: "VariableTracker"):
  996. # Call arg.__abs__()
  997. abs_method = BuiltinVariable(getattr).call_function(
  998. tx, [arg, ConstantVariable.create("__abs__")], {}
  999. )
  1000. return abs_method.call_function(tx, [], {})
  1001. def call_pos(self, tx, arg: "VariableTracker"):
  1002. # Call arg.__pos__()
  1003. pos_method = BuiltinVariable(getattr).call_function(
  1004. tx, [arg, ConstantVariable.create("__pos__")], {}
  1005. )
  1006. return pos_method.call_function(tx, [], {})
  1007. def call_index(self, tx, arg: "VariableTracker"):
  1008. if isinstance(arg, variables.TensorVariable):
  1009. unimplemented("unsupported index(tensor)")
  1010. arg = guard_if_dyn(arg)
  1011. constant_value = operator.index(arg)
  1012. return variables.ConstantVariable.create(constant_value)
  1013. def call_round(self, tx, arg, *args, **kwargs):
  1014. # Call arg.__round__()
  1015. round_method = BuiltinVariable(getattr).call_function(
  1016. tx, [arg, ConstantVariable.create("__round__")], {}
  1017. )
  1018. return round_method.call_function(tx, args, kwargs)
  1019. def call_range(self, tx, *args):
  1020. if check_unspec_or_constant_args(args, {}):
  1021. return variables.RangeVariable(args)
  1022. elif self._dynamic_args(*args):
  1023. args = [
  1024. variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args
  1025. ]
  1026. return variables.RangeVariable(args)
  1027. # None no-ops this handler and lets the driving function proceed
  1028. return None
  1029. def _dynamic_args(self, *args, **kwargs):
  1030. return any(isinstance(x, SymNodeVariable) for x in args) or any(
  1031. isinstance(x, SymNodeVariable) for x in kwargs.values()
  1032. )
  1033. def call_slice(self, tx, *args):
  1034. return variables.SliceVariable(args)
  1035. def _dyn_proxy(self, tx, *args, **kwargs):
  1036. from .builder import wrap_fx_proxy
  1037. return wrap_fx_proxy(
  1038. tx,
  1039. tx.output.create_proxy(
  1040. "call_function", self.fn, *proxy_args_kwargs(args, kwargs)
  1041. ),
  1042. )
  1043. def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs):
  1044. if self._dynamic_args(*args, **kwargs):
  1045. return self._dyn_proxy(tx, *args, **kwargs)
  1046. if isinstance(obj, variables.IteratorVariable):
  1047. # For non-list iterators, we will guard on vars that
  1048. # determine the control flow
  1049. return obj
  1050. cls = variables.BaseListVariable.cls_for(self.fn)
  1051. if obj is None:
  1052. return cls(
  1053. [],
  1054. mutable_local=MutableLocal(),
  1055. )
  1056. elif obj.has_unpack_var_sequence(tx):
  1057. if obj.source and not is_constant_source(obj.source):
  1058. if isinstance(obj, TupleIteratorVariable):
  1059. install_guard(
  1060. obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN)
  1061. )
  1062. else:
  1063. if (
  1064. getattr(obj, "source", False)
  1065. and isinstance(obj, ConstDictVariable)
  1066. and not istype(obj, SetVariable)
  1067. ):
  1068. tx.output.guard_on_key_order.add(obj.source.name())
  1069. install_guard(obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
  1070. return cls(
  1071. list(obj.unpack_var_sequence(tx)),
  1072. mutable_local=MutableLocal(),
  1073. )
  1074. def call_iter(self, tx, obj, *args, **kwargs):
  1075. # Handle the case where we are iterating over a tuple, list or iterator
  1076. ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
  1077. if ret is None:
  1078. # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway.
  1079. # If the object implements a __iter__ method, inlining effectively forwards the call to another iter call
1080. # (e.g. when __iter__ just returns iter(self.list)) or returns a user-defined iterator.
  1081. return obj.call_method(tx, "__iter__", args, kwargs)
  1082. return ret
  1083. call_tuple = _call_iter_tuple_list
  1084. call_list = _call_iter_tuple_list
  1085. def call_callable(self, tx, arg):
  1086. from .functions import BaseUserFunctionVariable
  1087. from .nn_module import NNModuleVariable
  1088. if isinstance(
  1089. arg,
  1090. (
  1091. variables.UserDefinedClassVariable,
  1092. BaseUserFunctionVariable,
  1093. NNModuleVariable,
  1094. ),
  1095. ):
  1096. return variables.ConstantVariable.create(True)
  1097. elif isinstance(arg, UserDefinedVariable):
  1098. return variables.ConstantVariable.create(callable(arg.value))
  1099. elif isinstance(arg, (ConstantVariable, SymNodeVariable, TensorVariable)):
  1100. return variables.ConstantVariable.create(False)
  1101. def call_cast(self, _, *args, **kwargs):
  1102. if len(args) == 2:
  1103. return args[1]
  1104. unimplemented(f"unsupported args to builtin cast(): {args} {kwargs}")
  1105. def call_dict(self, tx, *args, **kwargs):
  1106. return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs)
  1107. @staticmethod
  1108. def call_custom_dict(tx, user_cls, *args, **kwargs):
  1109. if not kwargs:
  1110. if not args:
  1111. args = ({},)
  1112. assert len(args) == 1
  1113. arg = args[0]
  1114. if isinstance(arg, dict):
  1115. return ConstDictVariable(arg, user_cls, mutable_local=MutableLocal())
  1116. elif isinstance(arg, variables.ConstDictVariable):
  1117. return arg.clone(user_cls=user_cls, mutable_local=MutableLocal())
  1118. elif isinstance(
  1119. arg,
  1120. (
  1121. ListVariable,
  1122. TupleVariable,
  1123. ListIteratorVariable,
  1124. ),
  1125. ):
  1126. items = dict(
  1127. x.unpack_var_sequence(tx) for x in arg.unpack_var_sequence(tx)
  1128. )
  1129. return ConstDictVariable(items, user_cls, mutable_local=MutableLocal())
  1130. elif not args and kwargs:
  1131. items = {ConstantVariable.create(k): v for k, v in kwargs.items()}
  1132. return variables.ConstDictVariable(
  1133. items, user_cls=user_cls, mutable_local=MutableLocal()
  1134. )
  1135. unimplemented(f"{user_cls.__name__}(): {args} {kwargs}")
  1136. @staticmethod
  1137. def call_custom_dict_fromkeys(tx, user_cls, *args, **kwargs):
  1138. assert user_cls in {dict, OrderedDict, defaultdict}
  1139. if kwargs:
  1140. # Only `OrderedDict.fromkeys` accepts `value` passed by keyword
  1141. assert user_cls is OrderedDict
  1142. assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs
  1143. args = (*args, kwargs.pop("value"))
  1144. if len(args) == 0:
  1145. raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0")
  1146. if len(args) == 1:
  1147. args = (*args, ConstantVariable.create(None))
  1148. assert len(args) == 2
  1149. arg, value = args
  1150. DictVariableType = (
  1151. ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable
  1152. )
  1153. if isinstance(arg, dict):
  1154. arg = [ConstantVariable.create(k) for k in arg.keys()]
  1155. return DictVariableType(
  1156. dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal()
  1157. )
  1158. elif arg.has_unpack_var_sequence(tx) and all(
  1159. is_hashable(v) for v in arg.unpack_var_sequence(tx)
  1160. ):
  1161. keys = arg.unpack_var_sequence(tx)
  1162. return DictVariableType(
  1163. dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal()
  1164. )
  1165. unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}")
  1166. def call_set(self, tx, *args, **kwargs):
1167. # Can we merge this implementation with call_dict's?
  1168. assert not kwargs
  1169. if not args:
  1170. return SetVariable([], mutable_local=MutableLocal())
  1171. assert len(args) == 1
  1172. arg = args[0]
  1173. if isinstance(arg, variables.SetVariable):
  1174. return arg.clone(mutable_local=MutableLocal())
  1175. elif arg.has_unpack_var_sequence(tx):
  1176. items = arg.unpack_var_sequence(tx)
  1177. return SetVariable(items, mutable_local=MutableLocal())
  1178. else:
  1179. unimplemented(f"set(): {args} {kwargs}")
  1180. def call_zip(self, tx, *args, **kwargs):
  1181. if kwargs:
  1182. assert len(kwargs) == 1 and "strict" in kwargs
  1183. if all(x.has_unpack_var_sequence(tx) for x in args):
  1184. unpacked = [arg.unpack_var_sequence(tx) for arg in args]
  1185. if kwargs.pop("strict", False) and len(unpacked) > 0:
  1186. if not all(len(u) == len(unpacked[0]) for u in unpacked):
  1187. raise UserError(
  1188. ValueError,
  1189. "zip() has one argument of len differing from others",
  1190. )
  1191. items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)]
  1192. return variables.TupleVariable(items)
  1193. def call_enumerate(self, tx, *args):
  1194. if len(args) == 1:
  1195. start = 0
  1196. else:
  1197. assert len(args) == 2
  1198. assert isinstance(args[1], variables.ConstantVariable)
  1199. start = args[1].as_python_constant()
  1200. if args[0].has_unpack_var_sequence(tx):
  1201. items = [
  1202. variables.TupleVariable(
  1203. [variables.ConstantVariable.create(idx), var],
  1204. )
  1205. for idx, var in enumerate(args[0].unpack_var_sequence(tx), start)
  1206. ]
  1207. return variables.TupleVariable(items)
  1208. def call_len(self, tx, *args, **kwargs):
  1209. return args[0].call_method(tx, "__len__", args[1:], kwargs)
  1210. def call_getitem(self, tx, *args, **kwargs):
  1211. return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
  1212. def call_isinstance(self, tx, arg, isinstance_type):
  1213. try:
  1214. arg_type = arg.python_type()
  1215. except NotImplementedError:
  1216. unimplemented(
  1217. f"isinstance({arg}, {isinstance_type}): can't determine type of {arg}"
  1218. )
  1219. isinstance_type = isinstance_type.as_python_constant()
  1220. if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
  1221. def _tensor_isinstance(tensor_var, tensor_type):
  1222. def check_type(ty):
  1223. if ty not in tensortype_to_dtype:
  1224. return issubclass(arg.python_type(), ty)
  1225. dtypes = tensortype_to_dtype[ty]
  1226. return arg.dtype in dtypes
  1227. if type(tensor_type) is tuple:
  1228. return any(check_type(ty) for ty in tensor_type)
  1229. else:
  1230. return check_type(tensor_type)
  1231. return variables.ConstantVariable.create(
  1232. _tensor_isinstance(arg, isinstance_type)
  1233. )
  1234. # UserDefinedObject with C extensions can have torch.Tensor attributes,
  1235. # so break graph.
  1236. if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
  1237. arg.value, types.MemberDescriptorType
  1238. ):
  1239. unimplemented(
  1240. f"isinstance called on UserDefinedClass {arg} {isinstance_type}"
  1241. )
  1242. # handle __instancecheck__ defined in user class
  1243. if (
  1244. isinstance(arg, variables.UserDefinedObjectVariable)
  1245. and "__instancecheck__" in isinstance_type.__class__.__dict__
  1246. ):
  1247. return variables.ConstantVariable.create(
  1248. isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value)
  1249. )
  1250. try:
  1251. val = issubclass(arg_type, isinstance_type)
  1252. except TypeError:
  1253. val = arg_type is isinstance_type
  1254. return variables.ConstantVariable.create(val)
  1255. def call_issubclass(self, tx, left_ty, right_ty):
  1256. """Checks if first arg is subclass of right arg"""
  1257. try:
  1258. left_ty_py = left_ty.as_python_constant()
  1259. right_ty_py = right_ty.as_python_constant()
  1260. except NotImplementedError:
  1261. unimplemented(
  1262. f"call_issubclass args not constant left_ty: {left_ty}, right_ty: {right_ty}"
  1263. )
  1264. return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))
  1265. def call_super(self, tx, a, b):
  1266. return variables.SuperVariable(a, b)
  1267. def call_next(self, tx, arg: VariableTracker):
  1268. try:
  1269. return arg.next_variable(tx)
  1270. except Unsupported as ex:
  1271. if isinstance(arg, variables.BaseListVariable):
  1272. ex.remove_from_stats()
  1273. return arg.items[0]
  1274. raise
  1275. def call_hasattr(self, tx, obj, attr):
  1276. if attr.is_python_constant():
  1277. name = attr.as_python_constant()
  1278. if isinstance(obj, variables.BuiltinVariable):
  1279. return variables.ConstantVariable(hasattr(obj.fn, name))
  1280. return obj.call_hasattr(tx, name)
  1281. def call_map(self, tx, fn, seq):
  1282. if seq.has_unpack_var_sequence(tx):
  1283. items = [fn.call_function(tx, [x], {}) for x in seq.unpack_var_sequence(tx)]
  1284. return variables.TupleVariable(items)
  1285. def call_sum(self, tx, seq, start=_SENTINEL):
1286. # Special case for sum on a list or tuple of floats and ints
  1287. if isinstance(seq, (variables.ListVariable, variables.TupleVariable)) and all(
  1288. isinstance(x, variables.ConstantVariable)
  1289. and isinstance(x.value, (int, float))
  1290. for x in seq.items
  1291. ):
  1292. if start is self._SENTINEL:
  1293. return variables.ConstantVariable.create(
  1294. sum(x.value for x in seq.items),
  1295. )
  1296. if isinstance(start, variables.ConstantVariable) and isinstance(
  1297. start.value, (int, float)
  1298. ):
  1299. return variables.ConstantVariable.create(
  1300. sum((x.value for x in seq.items), start=start.value),
  1301. )
  1302. if seq.has_unpack_var_sequence(tx):
  1303. if start is self._SENTINEL:
  1304. start = variables.ConstantVariable.create(0)
  1305. items = seq.unpack_var_sequence(tx)
  1306. return BuiltinVariable(functools.reduce).call_function(
  1307. tx,
  1308. [
  1309. BuiltinVariable(operator.add),
  1310. variables.TupleVariable(items),
  1311. start,
  1312. ],
  1313. {},
  1314. )
  1315. def call_StopIteration(self, tx, *args):
  1316. return variables.StopIterationVariable([*args])
  1317. def call_reduce(self, tx, function, iterable, initial=_SENTINEL):
  1318. if iterable.has_unpack_var_sequence(tx):
  1319. items = iterable.unpack_var_sequence(tx)
  1320. if initial is self._SENTINEL:
  1321. value, items = items[0], items[1:]
  1322. else:
  1323. value = initial
  1324. for element in items:
  1325. value = function.call_function(tx, [value, element], {})
  1326. return value
  1327. def call_getattr(
  1328. self, tx, obj: VariableTracker, name_var: VariableTracker, default=None
  1329. ):
  1330. from .. import trace_rules
  1331. from . import (
  1332. ConstantVariable,
  1333. GetAttrVariable,
  1334. PythonModuleVariable,
  1335. TorchInGraphFunctionVariable,
  1336. UserFunctionVariable,
  1337. )
  1338. from .builder import SourcelessBuilder, VariableBuilder
1339. if not name_var.is_python_constant():
1340. unimplemented("non-const getattr() name")
1341. name = name_var.as_python_constant()
        if tx.output.side_effects.is_attribute_mutation(obj):
            if isinstance(obj, variables.UnspecializedNNModuleVariable):
                if (
                    name
                    in (
                        "named_parameters",
                        "parameters",
                        "named_buffers",
                        "buffers",
                        "named_modules",
                        "modules",
                    )
                    and obj.is_state_mutated
                    and tx.output.side_effects.has_pending_mutation(obj)
                ):
                    unimplemented(
                        f"pending mutation on nn module, so graph breaking at {name!r} call"
                    )

            try:
                # re-read a pending side effect?
                return tx.output.side_effects.load_attr(obj, name)
            except KeyError:
                pass

        if default is not None:
            hasattr_var = self.call_hasattr(tx, obj, name_var)
            assert hasattr_var.as_python_constant() in (True, False)
            if not hasattr_var.as_python_constant():
                return default

        options = {}
        if obj.source:
            source = AttrSource(obj.source, name)
            options["source"] = source
        else:
            source = None

        if name == "__bases__":
            try:
                value = obj.as_python_constant()
                if isinstance(value, type):
                    bases = value.__bases__
                    if source is not None:
                        tuple_args = [
                            VariableBuilder(tx, GetItemSource(source, i))(b)
                            for i, b in enumerate(bases)
                        ]
                    else:
                        tuple_args = [SourcelessBuilder.create(tx, b) for b in bases]
                    return variables.TupleVariable(tuple_args, **options)
            except NotImplementedError:
                pass

        if isinstance(obj, variables.NNModuleVariable):
            return obj.var_getattr(tx, name)
        elif isinstance(
            obj,
            (
                variables.TensorVariable,
                variables.NamedTupleVariable,
                variables.ConstantVariable,
                variables.DistributedVariable,
                variables.UserDefinedClassVariable,
                variables.UserDefinedObjectVariable,
            ),
        ):
            try:
                return obj.var_getattr(tx, name)
            except NotImplementedError:
                return GetAttrVariable(obj, name, **options)
        elif isinstance(obj, TorchInGraphFunctionVariable):
            # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default.
            member = getattr(obj.value, name)
            if isinstance(
                member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
            ) and trace_rules.is_aten_op_or_tensor_method(member):
                return TorchInGraphFunctionVariable(member, **options)
        elif isinstance(obj, (PythonModuleVariable, DummyModule)):
            if obj.is_torch or name not in obj.value.__dict__:
                member = getattr(obj.value, name)
            else:
                member = obj.value.__dict__[name]

            if config.replay_record_enabled:
                tx.exec_recorder.record_module_access(obj.value, name, member)
            if source is not None:
                return VariableBuilder(tx, source)(member)
            else:
                return SourcelessBuilder.create(tx, member)
        elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
            return ConstantVariable.create(getattr(obj.fn, name))
        else:
            try:
                return obj.var_getattr(tx, name)
            except NotImplementedError:
                return GetAttrVariable(obj, name, **options)

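    # Handler for setattr()/attribute stores. Tensor attribute writes are the
    # interesting case: mutating `requires_grad` is unsupported, and `.data`
    # assignment is re-expressed as a no-grad call to Tensor.set_ followed by
    # a version-counter adjustment so autograd sees the same state as eager.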
    def call_setattr(
        self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker
    ):
        if isinstance(
            obj,
            (
                variables.DataClassVariable,
                variables.CustomizedDictVariable,
                variables.PlacementVariable,
                variables.UserDefinedObjectVariable,
            ),
        ):
            return obj.call_method(tx, "__setattr__", [name_var, val], {})
        elif (
            tx.output.side_effects.is_attribute_mutation(obj)
            and name_var.is_python_constant()
        ):
            name = name_var.as_python_constant()
            if isinstance(obj, variables.TensorVariable):
                from .builder import wrap_fx_proxy

                if name == "requires_grad":
                    # TODO(voz): Make it work properly
                    unimplemented(
                        "mutating requires_grad can introduce a new leaf from non-leaf or vice versa in "
                        "the middle of the graph, which aot_autograd does not currently know how to handle. "
                    )
                if name == "data":
                    # Remove the old reference from tracked fakes - if we don't do
                    # this, size and shape differences in the new .data value will
                    # cause tracked fakes to produce incorrect guards. This is sound
                    # because the TensorVariable coming out of set_() below will be a
                    # new one, and gets installed in tracked fakes.
                    to_remove = []
                    for tf in tx.output.tracked_fakes:
                        if tf.source == obj.source:
                            to_remove.append(tf)
                    for tf in to_remove:
                        tx.output.tracked_fakes.remove(tf)

                    # Step 1 - disable grads
                    with dynamo_disable_grad(tx), torch.no_grad():
                        # Step 2 - call `set_`
                        out = wrap_fx_proxy(
                            tx,
                            tx.output.create_proxy(
                                "call_function",
                                torch.Tensor.set_,
                                *proxy_args_kwargs([obj, val], {}),
                            ),
                        )

                    # Step 3 - drop the version counter; this is required to get
                    # .data setting to play correctly with the autograd engine.
                    # Essentially, dynamo is trying to faithfully preserve the
                    # (absurd) behavior of .data= from eager mode.
                    def _lower_version_count_by_1(x):
                        version = x._version
                        if version > 0:
                            version = version - 1
                        torch._C._autograd._unsafe_set_version_counter(x, version)
                        return x

                    tx.output.create_proxy(
                        "call_function",
                        _lower_version_count_by_1,
                        (out.as_proxy(),),
                        {},
                    )
                    _lower_version_count_by_1(obj.as_proxy().node.meta["example_value"])
                    # This handles options prop, guards and ends with a clone
                    # Step 4 - replace all references to the current object with the new one
                    return out

            tx.output.side_effects.store_attr(obj, name, val)
            if name == "_grad":
                tx.output.side_effects.store_attr(obj, "grad", val)

            return val
        elif isinstance(obj, variables.UserDefinedObjectVariable):
            unimplemented(
                f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}"
            )
        elif isinstance(obj, variables.NNModuleVariable):
            if not tx.output.is_root_tracer():
                raise AttributeMutationError(
                    "Can't inplace modify module params/buffers inside HigherOrderOp"
                )
            if name_var.is_python_constant() and isinstance(
                val, variables.TensorVariable
            ):
                assigning_fake_val = get_fake_value(val.as_proxy().node, tx)

                try:
                    getattr_var = obj.var_getattr(tx, name_var.as_python_constant())
                except AttributeError:
                    getattr_var = None

                if isinstance(getattr_var, variables.TensorVariable):
                    # get_fake_value will return the same fake tensor
                    existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx)

                    # same tensor identity, so setattr is a no-op
                    mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__")
                    if (
                        existing_fake_attr is assigning_fake_val
                        and mod_setattr is torch.nn.Module.__setattr__
                    ):
                        return getattr_var

            obj.convert_to_unspecialized(tx)
        # FIXME (tmanlaibaatar) this is an utter hack to unblock HuggingFace export.
        # Export generally doesn't want to allow mutations on objects directly,
        # but we don't have a good way to do this right now. For now, we make it
        # undefined behaviour and just set attributes directly on the
        # PretrainedConfig object.
        elif isinstance(obj, variables.dicts.HFPretrainedConfigVariable) and tx.export:
            if name_var.is_python_constant() and isinstance(
                val, variables.ConstantVariable
            ):
                setattr(
                    obj.obj, name_var.as_python_constant(), val.as_python_constant()
                )
                return ConstantVariable(None)

    def call_delattr(self, tx, obj: VariableTracker, name_var: VariableTracker):
        return self.call_setattr(tx, obj, name_var, variables.DeletedVariable())

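    # Handler for the builtin type(): returns a variable for the object's
    # Python type, built with a TypeSource when the object has a source so
    # that guards can be installed on the type itself.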
    def call_type(self, tx, obj: VariableTracker):
        from .builder import SourcelessBuilder, VariableBuilder

        try:
            py_type = obj.python_type()
        except NotImplementedError as error:
            raise UserError(
                UserErrorType.INVALID_INPUT,
                str(error),
                case_name="unknown_python_type",
            ) from None

        if obj.source is None:
            return SourcelessBuilder.create(tx, py_type)
        else:
            return VariableBuilder(tx, TypeSource(obj.source))(py_type)

    def call_reversed(self, tx, obj: VariableTracker):
        if obj.has_unpack_var_sequence(tx):
            items = list(reversed(obj.unpack_var_sequence(tx)))
            return variables.TupleVariable(items)

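    # Handler for sorted(): only sequences whose elements are all Python
    # constants are sorted at trace time; the optional `key` callable is
    # invoked through the tracer for each element.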
    def call_sorted(self, tx, obj: VariableTracker, **kwargs):
        if (
            obj.has_unpack_var_sequence(tx)
            and not isinstance(obj, variables.TensorVariable)
            and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx))
        ):
            function = kwargs.pop("key", None)
            reverse = kwargs.pop(
                "reverse", ConstantVariable.create(False)
            ).as_python_constant()
            assert len(kwargs) == 0
            if function:
                items = sorted(
                    obj.unpack_var_sequence(tx),
                    key=lambda x: function.call_function(
                        tx, [x], {}
                    ).as_python_constant(),
                    reverse=reverse,
                )
            else:
                items = sorted(
                    obj.unpack_var_sequence(tx),
                    key=lambda x: x.as_python_constant(),
                    reverse=reverse,
                )
            return variables.ListVariable(items)

    def call_chain(self, tx, *args):
        if all(obj.has_unpack_var_sequence(tx) for obj in args):
            items = []
            for obj in args:
                items.extend(obj.unpack_var_sequence(tx))
            return variables.TupleVariable(items)

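    # Handler for itertools.islice(): materializes the slice eagerly when the
    # iterable is unpackable and all slice arguments are constants.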
    def call_islice(self, tx, iterable, *args):
        if iterable.has_unpack_var_sequence(tx) and all(
            x.is_python_constant() for x in args
        ):
            const_args = [x.as_python_constant() for x in args]
            items = iterable.unpack_var_sequence(tx)
            items = list(itertools.islice(items, *const_args))
            return variables.TupleVariable(items)

    # neg is a constant fold function, so we only get here if constant fold is not valid
    def call_neg(self, tx, a):
        if isinstance(a, SymNodeVariable):
            return SymNodeVariable.create(
                tx,
                operator.neg(a.as_proxy()),
                sym_num=None,
            )
        # None no-ops this handler and lets the driving function proceed
        return None

    def call_format(self, tx, _format_string, *args, **kwargs):
        format_string = _format_string.as_python_constant()
        return variables.StringFormatVariable.create(format_string, args, kwargs)

    def call_id(self, tx, *args):
        if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable):
            nn_mod_variable = args[0]
            mod = tx.output.get_submodule(nn_mod_variable.module_key)
            return variables.ConstantVariable.create(id(mod))
        elif len(args) == 1 and isinstance(
            args[0], variables.UserDefinedObjectVariable
        ):
            install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH))
            constant_result = id(args[0].value)
            return variables.ConstantVariable.create(constant_result)
        else:
            unimplemented(f"call_id with args {args}")

    def call_deepcopy(self, tx, x):
        unimplemented(f"copy.deepcopy {repr(x)}")

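    # Comparison helpers. When at least one operand is a tensor, `is`/`is not`
    # are answered from fake-tensor identity, while rich comparisons are
    # lowered to a call_function proxy node (after a broadcastability check on
    # known sizes).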
    def _comparison_with_tensor(self, tx, left, right):
        from .builder import wrap_fx_proxy_cls
        from .tensor import supported_tensor_comparison_op_values

        op = self.fn

        if op in [operator.is_, operator.is_not]:
            is_result = (
                isinstance(left, TensorVariable)
                and isinstance(right, TensorVariable)
                and id(extract_fake_example_value(left.as_proxy().node))
                == id(extract_fake_example_value(right.as_proxy().node))
            )
            if op is operator.is_:
                return ConstantVariable.create(is_result)
            else:
                return ConstantVariable.create(not is_result)

        if op not in supported_tensor_comparison_op_values:
            unimplemented(f"{op.__name__}({left}, {right})")
        if (
            isinstance(left, TensorVariable)
            and isinstance(right, TensorVariable)
            and left.size is not None
            and right.size is not None
            and left.size != right.size
        ):
            try:
                torch.broadcast_shapes(left.size, right.size)
            except RuntimeError:
                # not broadcastable, can't be compared
                unimplemented(f"{op.__name__}({left}, {right})")

        tensor_cls = left if isinstance(left, TensorVariable) else right
        proxy = tx.output.create_proxy(
            "call_function", op, (left.as_proxy(), right.as_proxy()), {}
        )
        return wrap_fx_proxy_cls(
            type(tensor_cls),  # handle Ndarrays and Tensors
            tx,
            proxy,
        )

    def _comparison_with_symnode(self, tx, left, right):
        from .tensor import supported_tensor_comparison_op_values

        op = self.fn

        if op not in supported_tensor_comparison_op_values:
            unimplemented(f"{op.__name__}({left}, {right})")

        proxy = tx.output.create_proxy(
            "call_function", op, (left.as_proxy(), right.as_proxy()), {}
        )
        return SymNodeVariable.create(
            tx,
            proxy,
            sym_num=None,
        )

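    # Bitwise/boolean operators. Symbolic ints are lowered to proxy nodes,
    # set-like variables use Python set semantics, and constant/constant cases
    # return None so the generic constant handler can fold them instead.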
    def call_and_(self, tx, a, b):
        # Rely on constant_handler
        if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
            return None
        if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
            b, (SymNodeVariable, ConstantVariable)
        ):
            return SymNodeVariable.create(
                tx,
                tx.output.create_proxy(
                    "call_function", operator.and_, *proxy_args_kwargs([a, b], {})
                ),
                sym_num=None,
            )
        if hasattr(a, "set_items") and hasattr(b, "set_items"):
            return SetVariable(list(a.set_items & b.set_items))
        # None no-ops this handler and lets the driving function proceed
        return None

    def call_or_(self, tx, a, b):
        # Rely on constant_handler
        if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
            return None
        if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
            b, (SymNodeVariable, ConstantVariable)
        ):
            return SymNodeVariable.create(
                tx,
                tx.output.create_proxy(
                    "call_function", operator.or_, *proxy_args_kwargs([a, b], {})
                ),
                sym_num=None,
            )
        if hasattr(a, "set_items") and hasattr(b, "set_items"):
            return SetVariable(list(a.set_items | b.set_items))
        # None no-ops this handler and lets the driving function proceed
        return None

    def call_not_(self, tx, a):
        if isinstance(a, SymNodeVariable):
            return SymNodeVariable.create(
                tx,
                tx.output.create_proxy(
                    "call_function", operator.not_, *proxy_args_kwargs([a], {})
                ),
                sym_num=None,
            )

        # Unwrap the underlying ConstDictVariable
        if isinstance(a, DictView):
            a = a.dv_dict
        if isinstance(a, (ListVariable, ConstDictVariable)):
            return ConstantVariable.create(len(a.items) == 0)

        return None

    def call_contains(self, tx, a: VariableTracker, b: VariableTracker):
        return a.call_method(tx, "__contains__", [b], {})

    call_all = _polyfill_call_impl("all")
    call_any = _polyfill_call_impl("any")

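
# Context manager used by the `.data` branch of call_setattr above: it turns
# grad mode off via GradModeVariable so the surrounding trace also records the
# enable/disable transition, then restores the previous mode on exit.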
@contextlib.contextmanager
def dynamo_disable_grad(tx):
    from . import GradModeVariable

    org_value = torch.is_grad_enabled()
    gmv = GradModeVariable.create(tx, False)
    try:
        gmv.enter(tx)
        yield
    finally:
        gmv.exit(tx)