source.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644
  1. # mypy: allow-untyped-defs
import collections
import dataclasses
import enum
import keyword
from typing import Any, Optional, Union

from torch._guards import ChainedSource, GuardSource, Source

from . import utils
from .bytecode_transformation import create_call_function, create_instruction
from .utils import enum_repr
# Lookup tables mapping a base GuardSource to the GuardSource reported by a
# wrapping source. They are indexed directly (``table[guard_source]``) by
# ParamBufferSource / NNModuleSource, NotNNModuleSource and FSDPNNModuleSource,
# so a base GuardSource that is missing from a table raises KeyError on lookup.

# Promote a plain or NN-module guard source to its NN-module counterpart.
# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
# so those cases are omitted intentionally
_GUARD_SOURCE_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE,
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_NN_MODULE,
}
# Promote any local/global guard source (plain, NN-module or FSDP-module) to
# the matching FSDP-module guard source.
_GUARD_SOURCE_FSDP_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
}
# Demote any NN-module or FSDP-module guard source back to plain LOCAL/GLOBAL,
# keeping the local vs. global distinction.
_GUARD_SOURCE_NOT_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL,
    GuardSource.GLOBAL: GuardSource.GLOBAL,
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL,
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL,
}
  34. def is_constant_source(source):
  35. if isinstance(source, ConstantSource):
  36. return True
  37. try:
  38. if source.guard_source() == GuardSource.CONSTANT:
  39. return True
  40. except NotImplementedError:
  41. pass
  42. return False
  43. def reconstruct_getitem(
  44. source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice
  45. ):
  46. source.base.reconstruct(codegen)
  47. if isinstance(source.index, Source):
  48. source.index.reconstruct(codegen)
  49. else:
  50. if index_is_slice:
  51. assert isinstance(source, GetItemSource)
  52. codegen.append_output(codegen.create_load_const(source.unpack_slice()))
  53. else:
  54. codegen.append_output(codegen.create_load_const(source.index))
  55. @dataclasses.dataclass(frozen=True)
  56. class LocalSource(Source):
  57. local_name: str
  58. cell_or_freevar: bool = False
  59. def reconstruct(self, codegen):
  60. codegen.append_output(codegen.create_load(self.local_name))
  61. def guard_source(self):
  62. return GuardSource.LOCAL
  63. def name(self):
  64. return f"L[{repr(self.local_name)}]"
  65. @dataclasses.dataclass(frozen=True)
  66. class SyntheticLocalSource(Source):
  67. local_name: str
  68. def reconstruct(self, codegen):
  69. codegen.append_output(codegen.create_load(self.local_name))
  70. def guard_source(self):
  71. return GuardSource.SYNTHETIC_LOCAL
  72. def name(self):
  73. return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
  74. @dataclasses.dataclass(frozen=True)
  75. class RandomValueSource(Source):
  76. random_call_index: int
  77. def guard_source(self):
  78. return GuardSource.RANDOM_VALUE
  79. def reconstruct(self, codegen):
  80. codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
  81. codegen.append_output(codegen.create_load_const(self.random_call_index))
  82. codegen.append_output(create_instruction("BINARY_SUBSCR"))
  83. def name(self):
  84. return f"random_value_{self.random_call_index}"
  85. @dataclasses.dataclass(frozen=True)
  86. class GlobalSource(Source):
  87. global_name: str
  88. def reconstruct(self, codegen):
  89. codegen.append_output(
  90. codegen.create_load_global(self.global_name, False, add=True)
  91. )
  92. def guard_source(self):
  93. return GuardSource.GLOBAL
  94. def name(self):
  95. return f"G[{repr(self.global_name)}]"
  96. @dataclasses.dataclass(frozen=True)
  97. class GlobalWeakRefSource(Source):
  98. global_name: str
  99. def reconstruct(self, codegen):
  100. codegen.append_output(
  101. codegen.create_load_global(self.global_name, True, add=True)
  102. )
  103. codegen.extend_output(create_call_function(0, False))
  104. def guard_source(self):
  105. return GuardSource.GLOBAL
  106. def name(self):
  107. return f"G[{repr(self.global_name)}]()"
  108. @dataclasses.dataclass(frozen=True)
  109. class AttrSource(ChainedSource):
  110. member: str
  111. def __post_init__(self):
  112. assert self.base, "Can't construct an AttrSource without a valid base source"
  113. if "." in self.member:
  114. member_parts = self.member.split(".")
  115. object.__setattr__(
  116. self, "base", AttrSource(self.base, ".".join(member_parts[:-1]))
  117. )
  118. object.__setattr__(self, "member", member_parts[-1])
  119. def reconstruct(self, codegen):
  120. self.base.reconstruct(codegen)
  121. codegen.extend_output(codegen.create_load_attrs(self.member))
  122. def guard_source(self):
  123. return self.base.guard_source()
  124. def name(self):
  125. if not self.member.isidentifier():
  126. return f"getattr({self.base.name()}, {self.member!r})"
  127. return f"{self.base.name()}.{self.member}"
  128. # Represents tensor.grad source. It could be represented by AttrSource as well.
  129. # But, we could access grad field on tensor directly in C++ without going
  130. # through the Python bytecodes. Therefore, we use a separate source for grad
  131. # field.
  132. @dataclasses.dataclass(frozen=True)
  133. class GradSource(ChainedSource):
  134. member: str = "grad"
  135. def reconstruct(self, codegen):
  136. self.base.reconstruct(codegen)
  137. codegen.extend_output(codegen.create_load_attrs(self.member))
  138. def guard_source(self):
  139. return self.base.guard_source()
  140. def name(self):
  141. return f"{self.base.name()}.{self.member}"
  142. @dataclasses.dataclass(frozen=True)
  143. class ParamBufferSource(AttrSource):
  144. def guard_source(self):
  145. return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]
  146. # This source is intended to be used in places where a source is needed but it is expected
  147. # that the symbol will be simplified out later on. Symbols with ephemeral sources are
  148. # prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
  149. # source. Guarding on this source is an error.
  150. #
  151. # Example: During subclass view fake-ification, any close-over ViewFunc state should be
  152. # symbolicized / fake-ified to avoid invalid specialization during view replay. This source
  153. # is useful for symbols utilized in the middle of the view chain that are not expected to be
  154. # present within the final view shape metadata.
  155. @dataclasses.dataclass(frozen=True)
  156. class EphemeralSource(Source):
  157. desc: Optional[str] = None
  158. def guard_source(self):
  159. return GuardSource.EPHEMERAL
  160. def name(self):
  161. return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
  162. def make_guard(self):
  163. raise NotImplementedError
  164. def is_ephemeral(self):
  165. return True
  166. class TensorProperty(enum.Enum):
  167. SIZE = 0
  168. STRIDE = 1
  169. STORAGE_OFFSET = 2
  170. def method_name(self):
  171. if self is TensorProperty.SIZE:
  172. return "size"
  173. elif self is TensorProperty.STRIDE:
  174. return "stride"
  175. elif self is TensorProperty.STORAGE_OFFSET:
  176. return "storage_offset"
  177. @dataclasses.dataclass(frozen=True)
  178. class TensorPropertySource(ChainedSource):
  179. prop: TensorProperty
  180. idx: Optional[int] = None # None for STORAGE_OFFSET
  181. def __post_init__(self):
  182. assert self.base is not None
  183. if self.prop is TensorProperty.STORAGE_OFFSET:
  184. assert self.idx is None
  185. else:
  186. assert self.idx is not None
  187. def reconstruct(self, codegen):
  188. self.base.reconstruct(codegen)
  189. codegen.append_output(codegen.create_load_attr(self.prop.method_name()))
  190. if self.idx is not None:
  191. codegen.append_output(codegen.create_load_const(self.idx))
  192. codegen.extend_output(
  193. create_call_function(1 if self.idx is not None else 0, True)
  194. )
  195. def guard_source(self):
  196. return self.base.guard_source()
  197. def name(self):
  198. if self.prop is TensorProperty.SIZE:
  199. return f"{self.base.name()}.size()[{self.idx}]"
  200. elif self.prop is TensorProperty.STRIDE:
  201. return f"{self.base.name()}.stride()[{self.idx}]"
  202. elif self.prop is TensorProperty.STORAGE_OFFSET:
  203. assert self.idx is None
  204. return f"{self.base.name()}.storage_offset()"
  205. else:
  206. raise AssertionError(f"unhandled {self.prop}")
  207. @dataclasses.dataclass(frozen=True)
  208. class NegateSource(ChainedSource):
  209. def __post_init__(self):
  210. assert self.base is not None
  211. def reconstruct(self, codegen):
  212. raise NotImplementedError
  213. def guard_source(self):
  214. return self.base.guard_source()
  215. def name(self):
  216. # NB: use method call so that function stripping regexes work
  217. return f"{self.base.name()}.__neg__()"
  218. @dataclasses.dataclass(frozen=True)
  219. class ConvertIntSource(ChainedSource):
  220. def __post_init__(self):
  221. assert self.base is not None
  222. def reconstruct(self, codegen):
  223. self.base.reconstruct(codegen)
  224. def guard_source(self):
  225. return self.base.guard_source()
  226. def name(self):
  227. return f"cast_symbool_to_symint_guardless({self.base.name()})"
  228. @dataclasses.dataclass(frozen=True)
  229. class FlattenScriptObjectSource(ChainedSource):
  230. def __post_init__(self):
  231. assert self.base is not None
  232. def reconstruct(self, codegen):
  233. self.base.reconstruct(codegen)
  234. def guard_source(self):
  235. return self.base.guard_source()
  236. def name(self):
  237. return f"{self.base.name()}.__obj_flatten__()"
  238. @dataclasses.dataclass(frozen=True)
  239. class ScriptObjectQualifiedNameSource(ChainedSource):
  240. def __post_init__(self):
  241. assert self.base is not None
  242. def reconstruct(self, codegen):
  243. self.base.reconstruct(codegen)
  244. def guard_source(self):
  245. return self.base.guard_source()
  246. def name(self):
  247. return f"{self.base.name()}._type().qualified_name()"
  248. @dataclasses.dataclass(frozen=True)
  249. class DefaultsSource(ChainedSource):
  250. idx_key: Union[int, str]
  251. is_kw: bool = False
  252. field: str = dataclasses.field(init=False, repr=False, compare=False)
  253. _name: str = dataclasses.field(init=False, repr=False, compare=False)
  254. def __post_init__(self):
  255. assert (
  256. self.base
  257. ), "Base must be a valid source in order to properly track and guard this Defaults to its origin."
  258. if self.is_kw:
  259. assert isinstance(self.idx_key, str)
  260. object.__setattr__(self, "field", "__kwdefaults__")
  261. object.__setattr__(
  262. self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']"
  263. )
  264. else:
  265. assert isinstance(self.idx_key, int)
  266. object.__setattr__(self, "field", "__defaults__")
  267. object.__setattr__(
  268. self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
  269. )
  270. def reconstruct(self, codegen):
  271. self.base.reconstruct(codegen)
  272. codegen.extend_output(codegen.create_load_attrs(self.field))
  273. codegen.append_output(codegen.create_load_const(self.idx_key))
  274. codegen.append_output(create_instruction("BINARY_SUBSCR"))
  275. def guard_source(self):
  276. return self.base.guard_source()
  277. def name(self):
  278. return self._name
  279. @dataclasses.dataclass(frozen=True)
  280. class GetItemSource(ChainedSource):
  281. index: Any
  282. index_is_slice: bool = False
  283. def __post_init__(self):
  284. assert self.base is not None
  285. if isinstance(self.index, slice):
  286. # store the hashable version of the slice so the whole GetItemSource is hashable
  287. super().__setattr__("index", self.index.__reduce__())
  288. super().__setattr__("index_is_slice", True)
  289. def reconstruct(self, codegen):
  290. reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice)
  291. codegen.append_output(create_instruction("BINARY_SUBSCR"))
  292. def guard_source(self):
  293. return self.base.guard_source()
  294. def unpack_slice(self):
  295. assert self.index_is_slice
  296. slice_class, slice_args = self.index
  297. return slice_class(*slice_args)
  298. def name(self):
  299. # Index can be of following types
  300. # 1) ConstDictKeySource
  301. # 2) enum.Enum
  302. # 3) index is a slice - example 1:4
  303. # 4) index is a constant - example string, integer
  304. if isinstance(self.index, Source):
  305. if not isinstance(self.index, ConstDictKeySource):
  306. raise ValueError(
  307. "GetItemSource index must be a constant, enum or ConstDictKeySource"
  308. )
  309. return f"{self.base.name()}[{self.index.name()}]"
  310. elif self.index_is_slice:
  311. return f"{self.base.name()}[{self.unpack_slice()!r}]"
  312. elif isinstance(self.index, enum.Enum):
  313. return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
  314. else:
  315. return f"{self.base.name()}[{self.index!r}]"
  316. @dataclasses.dataclass(frozen=True)
  317. class ConstDictKeySource(GetItemSource):
  318. def is_dict_key(self):
  319. return True
  320. def reconstruct(self, codegen):
  321. codegen.load_import_from(utils.__name__, "dict_keys_getitem")
  322. self.base.reconstruct(codegen)
  323. codegen.append_output(codegen.create_load_const(self.index))
  324. codegen.extend_output(create_call_function(2, True))
  325. def name(self):
  326. # The list creation will be CSE'd by PyExprCSEPass
  327. return f"list({self.base.name()}.keys())[{self.index!r}]"
  328. @dataclasses.dataclass(frozen=True)
  329. class TupleIteratorGetItemSource(GetItemSource):
  330. def reconstruct(self, codegen):
  331. codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
  332. self.base.reconstruct(codegen)
  333. codegen.append_output(codegen.create_load_const(self.index))
  334. codegen.extend_output(create_call_function(2, True))
  335. def name(self):
  336. return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
  337. @dataclasses.dataclass(frozen=True)
  338. class TypeSource(ChainedSource):
  339. def __post_init__(self):
  340. assert self.base is not None
  341. def reconstruct(self, codegen):
  342. codegen.load_import_from("builtins", "type")
  343. self.base.reconstruct(codegen)
  344. codegen.extend_output(create_call_function(1, True))
  345. def guard_source(self):
  346. return self.base.guard_source()
  347. def name(self):
  348. return f"type({self.base.name()})"
  349. @dataclasses.dataclass(frozen=True)
  350. class ODictGetItemSource(ChainedSource):
  351. index: Any
  352. def __post_init__(self):
  353. assert self.base is not None
  354. def reconstruct(self, codegen):
  355. codegen.append_output(
  356. codegen._create_load_const(collections.OrderedDict.__getitem__)
  357. )
  358. reconstruct_getitem(self, codegen, index_is_slice=False)
  359. codegen.extend_output(create_call_function(2, True))
  360. def guard_source(self):
  361. return self.base.guard_source()
  362. def name(self):
  363. if isinstance(self.index, type):
  364. rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}'
  365. return f"___odict_getitem({self.base.name()}, {rep})"
  366. elif isinstance(self.index, Source):
  367. return f"___odict_getitem({self.base.name()}, {self.index.name()})"
  368. else:
  369. return f"___odict_getitem({self.base.name()}, {self.index!r})"
  370. @dataclasses.dataclass(frozen=True)
  371. class OptimizerSource(ChainedSource):
  372. def reconstruct(self, codegen):
  373. self.base.reconstruct(codegen)
  374. def guard_source(self):
  375. return self.base.guard_source()
  376. def name(self):
  377. return self.base.name()
  378. @dataclasses.dataclass(frozen=True)
  379. class NNModuleSource(ChainedSource):
  380. def reconstruct(self, codegen):
  381. self.base.reconstruct(codegen)
  382. def guard_source(self):
  383. return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]
  384. def name(self):
  385. return self.base.name()
  386. @dataclasses.dataclass(frozen=True)
  387. class NotNNModuleSource(NNModuleSource):
  388. def guard_source(self):
  389. return _GUARD_SOURCE_NOT_NN_MODULE[self.base.guard_source()]
  390. @dataclasses.dataclass(frozen=True)
  391. class FSDPNNModuleSource(NNModuleSource):
  392. def guard_source(self):
  393. return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()]
  394. @dataclasses.dataclass(frozen=True)
  395. class GlobalStateSource(Source):
  396. def name(self):
  397. return ""
  398. def guard_source(self):
  399. return GuardSource.GLOBAL
  400. @dataclasses.dataclass(frozen=True)
  401. class ConstantSource(Source):
  402. source_name: str
  403. def reconstruct(self, codegen):
  404. codegen.append_output(
  405. codegen.create_load_global(self.source_name, False, add=False)
  406. )
  407. def guard_source(self):
  408. return GuardSource.CONSTANT
  409. def name(self):
  410. return self.source_name
  411. def make_guard(self, fn):
  412. raise NotImplementedError
  413. @dataclasses.dataclass(frozen=True)
  414. class NumpyTensorSource(ChainedSource):
  415. def name(self) -> str:
  416. return f"___from_numpy({self.base.name()})"
  417. def guard_source(self):
  418. return self.base.guard_source()
  419. def reconstruct(self, codegen):
  420. codegen.load_import_from("torch", "as_tensor")
  421. self.base.reconstruct(codegen)
  422. codegen.extend_output(create_call_function(1, True))
  423. # NB: We don't expect you to actually ever generate guards against this
  424. # source, it is ephemeral
  425. @dataclasses.dataclass(frozen=True)
  426. class FloatTensorSource(ChainedSource):
  427. def name(self) -> str:
  428. return f"___as_tensor({self.base.name()})"
  429. def guard_source(self):
  430. return self.base.guard_source()
  431. @dataclasses.dataclass(frozen=True)
  432. class CallMethodItemSource(ChainedSource):
  433. def name(self) -> str:
  434. return f"{self.base.name()}.item()"
  435. def guard_source(self):
  436. return self.base.guard_source()
  437. # This is a synthetic source that is associated with the singleton
  438. # shape env guard we always register for all frames. We get the actual
  439. # guard contents from the ambient ShapeEnv
  440. @dataclasses.dataclass(frozen=True)
  441. class ShapeEnvSource(Source):
  442. def name(self):
  443. return ""
  444. def guard_source(self):
  445. return GuardSource.SHAPE_ENV
  446. @dataclasses.dataclass(frozen=True)
  447. class BackwardStateSource(Source):
  448. def name(self):
  449. return ""
  450. def guard_source(self):
  451. return GuardSource.BACKWARD_STATE
  452. def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
  453. if isinstance(source, ChainedSource):
  454. return is_from_local_source(
  455. source.base, allow_cell_or_freevar=allow_cell_or_freevar
  456. )
  457. if not isinstance(source, LocalSource):
  458. return False
  459. if not allow_cell_or_freevar and source.cell_or_freevar:
  460. return False
  461. return True
  462. def is_from_flatten_script_object_source(source: Source):
  463. if isinstance(source, FlattenScriptObjectSource):
  464. return True
  465. elif isinstance(source, ChainedSource):
  466. return is_from_flatten_script_object_source(source.base)
  467. return False
  468. def is_from_optimizer_source(source: Source):
  469. if isinstance(source, OptimizerSource):
  470. return True
  471. if isinstance(source, ChainedSource):
  472. return is_from_optimizer_source(source.base)
  473. return False
  474. # TODO: can probably write a generic "test this on everything in the chain"
  475. # helper
  476. def is_from_defaults(source: Source):
  477. if isinstance(source, DefaultsSource):
  478. return True
  479. if isinstance(source, ChainedSource):
  480. return is_from_defaults(source.base)
  481. return False
  482. def is_cell_contents(source: Source):
  483. return isinstance(source, AttrSource) and source.member == "cell_contents"