_check.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # mypy: allow-untyped-defs
  2. import ast
  3. import inspect
  4. import textwrap
  5. import warnings
  6. import torch
  7. class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
  8. """Check the ``__init__`` method of a given ``nn.Module``.
  9. It ensures that all instance-level attributes can be properly initialized.
  10. Specifically, we do type inference based on attribute values...even
  11. if the attribute in question has already been typed using
  12. Python3-style annotations or ``torch.jit.annotate``. This means that
  13. setting an instance-level attribute to ``[]`` (for ``List``),
  14. ``{}`` for ``Dict``), or ``None`` (for ``Optional``) isn't enough
  15. information for us to properly initialize that attribute.
  16. An object of this class can walk a given ``nn.Module``'s AST and
  17. determine if it meets our requirements or not.
  18. Known limitations
  19. 1. We can only check the AST nodes for certain constructs; we can't
  20. ``eval`` arbitrary expressions. This means that function calls,
  21. class instantiations, and complex expressions that resolve to one of
  22. the "empty" values specified above will NOT be flagged as
  23. problematic.
  24. 2. We match on string literals, so if the user decides to use a
  25. non-standard import (e.g. `from typing import List as foo`), we
  26. won't catch it.
  27. Example:
  28. .. code-block:: python
  29. class M(torch.nn.Module):
  30. def fn(self):
  31. return []
  32. def __init__(self):
  33. super().__init__()
  34. self.x: List[int] = []
  35. def forward(self, x: List[int]):
  36. self.x = x
  37. return 1
  38. The above code will pass the ``AttributeTypeIsSupportedChecker``
  39. check since we have a function call in ``__init__``. However,
  40. it will still fail later with the ``RuntimeError`` "Tried to set
  41. nonexistent attribute: x. Did you forget to initialize it in
  42. __init__()?".
  43. Args:
  44. nn_module - The instance of ``torch.nn.Module`` whose
  45. ``__init__`` method we wish to check
  46. """
  47. def check(self, nn_module: torch.nn.Module) -> None:
  48. source_lines = inspect.getsource(nn_module.__class__.__init__)
  49. # Ignore comments no matter the indentation
  50. def is_useless_comment(line):
  51. line = line.strip()
  52. return line.startswith("#") and not line.startswith("# type:")
  53. source_lines = "\n".join(
  54. [l for l in source_lines.split("\n") if not is_useless_comment(l)]
  55. )
  56. # This AST only contains the `__init__` method of the nn.Module
  57. init_ast = ast.parse(textwrap.dedent(source_lines))
  58. # Get items annotated in the class body
  59. self.class_level_annotations = list(nn_module.__annotations__.keys())
  60. # Flag for later
  61. self.visiting_class_level_ann = False
  62. self.visit(init_ast)
  63. def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool:
  64. if ann_type == "List":
  65. # Assigning `[]` to a `List` type gives you a Node where
  66. # value=List(elts=[], ctx=Load())
  67. if not isinstance(node, ast.List):
  68. return False
  69. if node.elts:
  70. return False
  71. elif ann_type == "Dict":
  72. # Assigning `{}` to a `Dict` type gives you a Node where
  73. # value=Dict(keys=[], values=[])
  74. if not isinstance(node, ast.Dict):
  75. return False
  76. if node.keys:
  77. return False
  78. elif ann_type == "Optional":
  79. # Assigning `None` to an `Optional` type gives you a
  80. # Node where value=Constant(value=None, kind=None)
  81. if not isinstance(node, ast.Constant):
  82. return False
  83. if node.value: # type: ignore[attr-defined]
  84. return False
  85. return True
  86. def visit_Assign(self, node):
  87. """Store assignment state when assigning to a Call Node.
  88. If we're visiting a Call Node (the right-hand side of an
  89. assignment statement), we won't be able to check the variable
  90. that we're assigning to (the left-hand side of an assignment).
  91. Because of this, we need to store this state in visitAssign.
  92. (Luckily, we only have to do this if we're assigning to a Call
  93. Node, i.e. ``torch.jit.annotate``. If we're using normal Python
  94. annotations, we'll be visiting an AnnAssign Node, which has its
  95. target built in.)
  96. """
  97. try:
  98. if (
  99. isinstance(node.value, ast.Call)
  100. and node.targets[0].attr in self.class_level_annotations
  101. ):
  102. self.visiting_class_level_ann = True
  103. except AttributeError:
  104. return
  105. self.generic_visit(node)
  106. self.visiting_class_level_ann = False
  107. def visit_AnnAssign(self, node):
  108. """Visit an AnnAssign node in an ``nn.Module``'s ``__init__`` method.
  109. It checks if it conforms to our attribute annotation rules."""
  110. # If we have a local variable
  111. try:
  112. if node.target.value.id != "self":
  113. return
  114. except AttributeError:
  115. return
  116. # If we have an attribute that's already been annotated at the
  117. # class level
  118. if node.target.attr in self.class_level_annotations:
  119. return
  120. # TODO @ansley: add `Union` once landed
  121. # NB: Even though `Tuple` is a "container", we don't want to
  122. # check for it here. `Tuple` functions as an type with an
  123. # "infinite" number of subtypes, in the sense that you can have
  124. # `Tuple[())]`, `Tuple[T1]`, `Tuple[T2]`, `Tuple[T1, T2]`,
  125. # `Tuple[T2, T1]` and so on, and none of these subtypes can be
  126. # used in place of the other. Therefore, assigning an empty
  127. # tuple in `__init__` CORRECTLY means that that variable
  128. # cannot be reassigned later to a non-empty tuple. Same
  129. # deal with `NamedTuple`
  130. containers = {"List", "list", "Dict", "dict", "Optional"}
  131. # If we're not evaluating one of the specified problem types
  132. try:
  133. if node.annotation.value.id not in containers:
  134. return
  135. except AttributeError:
  136. # To evaluate a base type (`str`, `int`, etc.), we would
  137. # have needed to get the name through `node.annotation.id`
  138. # instead of `node.annotation.value.id`. Seems that we're
  139. # not evaluating one of our "containers"
  140. return
  141. # Check if the assigned variable is empty
  142. ann_type = node.annotation.value.id
  143. if not self._is_empty_container(node.value, ann_type):
  144. return
  145. warnings.warn(
  146. "The TorchScript type system doesn't support "
  147. "instance-level annotations on empty non-base "
  148. "types in `__init__`. Instead, either 1) use a "
  149. "type annotation in the class body, or 2) wrap "
  150. "the type in `torch.jit.Attribute`."
  151. )
  152. def visit_Call(self, node):
  153. """Determine if a Call node is 'torch.jit.annotate' in __init__.
  154. Visit a Call node in an ``nn.Module``'s ``__init__``
  155. method and determine if it's ``torch.jit.annotate``. If so,
  156. see if it conforms to our attribute annotation rules.
  157. """
  158. # If we have an attribute that's already been annotated at the
  159. # class level
  160. if self.visiting_class_level_ann:
  161. return
  162. # If this isn't a call to `torch.jit.annotate`
  163. try:
  164. if (
  165. node.func.value.value.id != "torch"
  166. or node.func.value.attr != "jit"
  167. or node.func.attr != "annotate"
  168. ):
  169. self.generic_visit(node)
  170. elif (
  171. node.func.value.value.id != "jit" or node.func.value.attr != "annotate"
  172. ):
  173. self.generic_visit(node)
  174. except AttributeError:
  175. # Looks like we didn't even have the right node structure
  176. # to check for `torch.jit.annotate` in the first place
  177. self.generic_visit(node)
  178. # Invariant: we have a `torch.jit.annotate` or a
  179. # `torch.annotate` call
  180. # A Call Node for `torch.jit.annotate` should have an `args`
  181. # list of length 2 where args[0] represents the annotation and
  182. # args[1] represents the actual value
  183. if len(node.args) != 2:
  184. return
  185. if not isinstance(node.args[0], ast.Subscript):
  186. return
  187. # See notes in `visit_AnnAssign` r.e. containers
  188. containers = {"List", "Dict", "Optional"}
  189. try:
  190. ann_type = node.args[0].value.id # type: ignore[attr-defined]
  191. except AttributeError:
  192. return
  193. if ann_type not in containers:
  194. return
  195. # Check if the assigned variable is empty
  196. if not self._is_empty_container(node.args[1], ann_type):
  197. return
  198. warnings.warn(
  199. "The TorchScript type system doesn't support "
  200. "instance-level annotations on empty non-base "
  201. "types in `__init__`. Instead, either 1) use a "
  202. "type annotation in the class body, or 2) wrap "
  203. "the type in `torch.jit.Attribute`."
  204. )