constant.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # mypy: ignore-errors
  2. import operator
  3. from typing import Dict, List
  4. import torch
  5. from torch._dynamo.source import GetItemSource
  6. from .. import variables
  7. from ..exc import unimplemented, UserError, UserErrorType
  8. from ..guards import GuardBuilder, install_guard
  9. from ..utils import common_constant_types, istype, np
  10. from .base import typestr, VariableTracker
  11. _type_to_assert_reason = {
  12. # NB - We CAN have ConstantVariable.create(set) because of how sets interact with guards.
  13. # A locally created set should always become a SetVariable, as the items in the set will already either be sourced
  14. # from somewhere else, or unsourced. An input set would imply sources derived from set contents. For example, an
  15. # input list's contents will have a source like some_list[0], some_list[1][1], etc. For a set, arbitrary access is
  16. # not possible. This is a solvable problem, but one we have not taken on yet. As such, input sets are not allowed to
  17. # become SetVariables. The solution here is to create a ConstantSetVariable that is more like a ConstantVariable.
  18. # As this does not exist, we cannot add sets to this invariant.
  19. list: "List types must use ListVariable.",
  20. dict: "Dict types must use ConstDictVariable.",
  21. torch.Tensor: "Tensor types must use TensorVariable.",
  22. torch.SymInt: "SymInts must use SymNodeVariable. "
  23. "If the underlying value is static, we will create a ConstantVariable and specialize.",
  24. torch.SymFloat: "SymInts must use SymNodeVariable",
  25. }
  26. class ConstantVariable(VariableTracker):
  27. @staticmethod
  28. def create(value, **kwargs) -> VariableTracker:
  29. source = kwargs.get("source", None)
  30. is_literal = ConstantVariable.is_literal(value)
  31. if not is_literal:
  32. for disallowed_type, reason in _type_to_assert_reason.items():
  33. assert not isinstance(value, disallowed_type), reason
  34. # Routing for list and tuple literals.
  35. if is_literal and isinstance(value, (list, tuple, set, frozenset)):
  36. items = []
  37. for i, x in enumerate(value):
  38. item_source = GetItemSource(source, i) if source else None
  39. if item_source:
  40. install_guard(item_source.make_guard(GuardBuilder.CONSTANT_MATCH))
  41. items.append(
  42. ConstantVariable.create(
  43. x,
  44. source=item_source,
  45. )
  46. )
  47. if isinstance(value, (list, tuple)):
  48. return variables.BaseListVariable.cls_for(type(value))(items, **kwargs)
  49. else:
  50. assert isinstance(value, (set, frozenset)), type(value)
  51. return variables.SetVariable(items)
  52. return ConstantVariable(value, **kwargs)
  53. def __init__(self, value, **kwargs):
  54. super().__init__(**kwargs)
  55. if not ConstantVariable.is_literal(value):
  56. for disallowed_type, reason in _type_to_assert_reason.items():
  57. assert not isinstance(value, disallowed_type), reason
  58. assert not isinstance(
  59. value, (list, tuple)
  60. ), "ConstantVariable(list) is banned - please create a ListVariable(items)"
  61. if np is not None and isinstance(value, np.number):
  62. self.value = value.item()
  63. else:
  64. self.value = value
  65. def as_proxy(self):
  66. return self.value
  67. def __str__(self):
  68. return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})"
  69. def python_type(self):
  70. return type(self.value)
  71. def as_python_constant(self):
  72. return self.value
  73. def is_python_constant(self):
  74. return True
  75. @property
  76. def items(self):
  77. """
  78. Need this when adding a BaseListVariable and a ConstantVariable together.
  79. Happens in detectron2.
  80. """
  81. return self.unpack_var_sequence(tx=None)
  82. def getitem_const(self, arg: VariableTracker):
  83. return ConstantVariable.create(
  84. self.value[arg.as_python_constant()],
  85. )
  86. @staticmethod
  87. def is_literal(obj):
  88. if type(obj) in common_constant_types:
  89. return True
  90. # The structure within is_literal get routed to variables.BaseListVariable
  91. if type(obj) in (list, tuple, set, frozenset, torch.Size):
  92. return all(ConstantVariable.is_literal(x) for x in obj)
  93. return False
  94. def unpack_var_sequence(self, tx):
  95. try:
  96. return [ConstantVariable.create(x) for x in self.as_python_constant()]
  97. except TypeError as e:
  98. raise NotImplementedError from e
  99. def const_getattr(self, tx, name):
  100. if isinstance(self.value, type):
  101. raise UserError(
  102. UserErrorType.ANTI_PATTERN,
  103. "Can't access members of type(obj) for a generated custom object. "
  104. "Please use __class__ instead",
  105. case_name="type_reflection_method",
  106. )
  107. member = getattr(self.value, name)
  108. if callable(member):
  109. raise NotImplementedError
  110. return member
  111. def call_method(
  112. self,
  113. tx,
  114. name,
  115. args: "List[VariableTracker]",
  116. kwargs: "Dict[str, VariableTracker]",
  117. ) -> "VariableTracker":
  118. from .tensor import SymNodeVariable
  119. if name == "format" and istype(self.value, str):
  120. return variables.BuiltinVariable(str.format).call_function(
  121. tx, [self, *args], kwargs
  122. )
  123. if any(isinstance(x, SymNodeVariable) for x in args):
  124. # Promote to SymNodeVariable for operations involving dynamic shapes.
  125. return variables.SymNodeVariable(self.as_proxy(), self.value).call_method(
  126. tx, name, args, kwargs
  127. )
  128. try:
  129. const_args = [a.as_python_constant() for a in args]
  130. const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
  131. except NotImplementedError:
  132. return super().call_method(tx, name, args, kwargs)
  133. def has_arith_binop(num_ty):
  134. return (
  135. isinstance(self.value, num_ty)
  136. and hasattr(operator, name)
  137. and len(args) == 1
  138. and args[0].is_python_constant()
  139. )
  140. if isinstance(self.value, str) and name in str.__dict__.keys():
  141. method = getattr(self.value, name)
  142. return ConstantVariable.create(method(*const_args, **const_kwargs))
  143. elif has_arith_binop(int) or has_arith_binop(float):
  144. op = getattr(operator, name)
  145. add_target = const_args[0]
  146. if isinstance(add_target, (torch.SymInt, torch.SymFloat)):
  147. from .tensor import SymNodeVariable
  148. # Addition between a non sym and sym makes a sym
  149. # sym_num = tx.output.register_attr_or_module(
  150. # add_target, f"sym_shape_{add_target}", source=None
  151. # )
  152. proxy = tx.output.create_proxy(
  153. "call_function", op, (self.value, add_target), {}
  154. )
  155. return SymNodeVariable.create(tx, proxy, add_target)
  156. return ConstantVariable.create(op(self.value, add_target))
  157. elif name == "__len__" and not (args or kwargs):
  158. return ConstantVariable.create(len(self.value))
  159. elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
  160. assert not kwargs
  161. search = args[0].as_python_constant()
  162. result = search in self.value
  163. return ConstantVariable.create(result)
  164. unimplemented(f"const method call {typestr(self.value)}.{name}")
  165. def call_hasattr(self, tx, name: str) -> "VariableTracker":
  166. result = hasattr(self.value, name)
  167. return variables.ConstantVariable.create(result)
  168. class EnumVariable(VariableTracker):
  169. def __init__(self, value, **kwargs):
  170. super().__init__(**kwargs)
  171. self.value = value
  172. @classmethod
  173. def create(cls, cls_type, value_vt, options):
  174. if isinstance(value_vt, variables.ConstantVariable):
  175. for member in list(cls_type):
  176. if member.value == value_vt.as_python_constant():
  177. return cls(member, **options)
  178. unimplemented("Enum variable is constructed with non constant values")
  179. def as_proxy(self):
  180. return self.value
  181. def __str__(self):
  182. return f"EnumVariable({type(self.value)})"
  183. def python_type(self):
  184. return type(self.value)
  185. def as_python_constant(self):
  186. return self.value
  187. def const_getattr(self, tx, name):
  188. member = getattr(self.value, name)
  189. if callable(member):
  190. raise NotImplementedError
  191. return member