# torch/jit/_dataclass_impls.py
  1. # mypy: allow-untyped-defs
  2. # Functions for synthesizing magic methods for JIT-compiled dataclasses
  3. import ast
  4. import dataclasses
  5. import inspect
  6. import os
  7. from functools import partial
  8. from typing import Callable, Dict, List
  9. from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional
  10. from torch._sources import ParsedDef, SourceContext
  11. def _get_fake_filename(cls, method_name):
  12. return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)
  13. def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
  14. body = "\n".join(f" {b}" for b in body_lines)
  15. decl = f"def {name}{signature}:\n{body}"
  16. # Parse the function declaration
  17. try:
  18. py_ast = ast.parse(decl)
  19. except SyntaxError as e:
  20. # This should only happen if there's some unforeseeable change
  21. # in the dataclasses module that makes our synthesized code fail
  22. raise RuntimeError(
  23. f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
  24. "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
  25. ) from e
  26. fake_filename = _get_fake_filename(cls, name)
  27. # Parse the function
  28. return ParsedDef(
  29. py_ast,
  30. ctx=SourceContext(
  31. source=decl, filename=fake_filename, file_lineno=0, leading_whitespace_len=0
  32. ),
  33. source=decl,
  34. filename=fake_filename,
  35. file_lineno=0,
  36. )
  37. def synthesize__init__(cls) -> ParsedDef:
  38. # Supporting default factories in the way that people expect would sort of require us to
  39. # allow compiling lambda functions, which is not currently supported.
  40. if any(
  41. field.default_factory is not dataclasses.MISSING
  42. for field in dataclasses.fields(cls)
  43. ):
  44. raise NotImplementedError(
  45. "Default factory initializers are not supported in TorchScript dataclasses"
  46. )
  47. # Simply read off the generated __init__ signature from CPython's implementation. It'll be
  48. # almost correct except for InitVar annotations, which we need to handle specially.
  49. signature = inspect.signature(cls.__init__)
  50. # Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
  51. # see CPython commit here https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
  52. init_vars: List[str] = []
  53. params = []
  54. for name, param in signature.parameters.items():
  55. ann = param.annotation
  56. if isinstance(ann, dataclasses.InitVar):
  57. # The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here
  58. init_vars.append(name)
  59. params.append(param.replace(annotation=ann.type)) # type: ignore[attr-defined]
  60. else:
  61. params.append(param)
  62. signature = signature.replace(parameters=params)
  63. body = [
  64. # Assign all attributes to self
  65. f"self.{field.name} = {field.name}"
  66. for field in dataclasses.fields(cls)
  67. if field.init and field.name not in init_vars
  68. ]
  69. # Call user's impl of __post_init__ if it exists
  70. if hasattr(cls, "__post_init__"):
  71. body.append("self.__post_init__(" + ", ".join(init_vars) + ")")
  72. return compose_fn(cls, "__init__", body or ["pass"], signature=str(signature))
  73. # This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
  74. def synthesize__repr__(cls) -> ParsedDef:
  75. return compose_fn(
  76. cls,
  77. "__repr__",
  78. [
  79. f"return '{cls.__name__}("
  80. + ", ".join(
  81. [
  82. f"{field.name}=self.{field.name}"
  83. for field in dataclasses.fields(cls)
  84. if field.repr
  85. ]
  86. )
  87. + ")'"
  88. ],
  89. signature="(self) -> str",
  90. )
  91. def synthesize__hash__(cls) -> ParsedDef:
  92. return compose_fn(
  93. cls,
  94. "__hash__",
  95. [
  96. # This is just a placeholder to prevent compilation from failing; this won't even get called at
  97. # all right now because the TorchScript interpreter doesn't call custom __hash__ implementations
  98. "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
  99. ],
  100. signature="(self) -> int",
  101. )
  102. # Implementation for __eq__ and __ne__
  103. def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
  104. return synthesize_comparison(
  105. cls,
  106. name,
  107. allow_eq=True,
  108. raise_on_none=False,
  109. inner=[f"if val1 {converse} val2: return False"],
  110. )
  111. def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
  112. return synthesize_comparison(
  113. cls,
  114. name,
  115. allow_eq,
  116. raise_on_none=True,
  117. inner=[
  118. f"if val1 {op} val2: return True",
  119. f"elif val2 {op} val1: return False",
  120. ],
  121. )
  122. def synthesize_comparison(
  123. cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]
  124. ) -> ParsedDef:
  125. body = []
  126. for field in dataclasses.fields(cls):
  127. if not field.compare:
  128. continue
  129. body.extend(
  130. [
  131. f"val1 = self.{field.name}",
  132. f"val2 = other.{field.name}",
  133. ]
  134. )
  135. body.extend(
  136. inner
  137. if not is_optional(field.type)
  138. else [
  139. # Type refinement for optional fields; we need this to avoid type errors from the interpreter
  140. "if val1 is not None and val2 is not None:",
  141. *[" " + line for line in inner],
  142. "elif (val1 is None) != (val2 is None):",
  143. f" raise TypeError('Cannot compare {cls.__name__} with None')"
  144. if raise_on_none
  145. else " return False",
  146. ]
  147. )
  148. body.append(f"return {allow_eq}")
  149. return compose_fn(
  150. cls, name, body, signature=f"(self, other: {cls.__name__}) -> bool"
  151. )
  152. DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
  153. "__init__": synthesize__init__,
  154. "__repr__": synthesize__repr__,
  155. "__hash__": synthesize__hash__,
  156. "__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
  157. "__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
  158. "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
  159. "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
  160. "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
  161. "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
  162. }