refs.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # mypy: ignore-errors
  2. from torch.testing._internal.opinfo.core import (
  3. BinaryUfuncInfo,
  4. OpInfo,
  5. ReductionOpInfo,
  6. UnaryUfuncInfo,
  7. )
  8. # NOTE [Python References]
  9. # Python References emulate existing PyTorch operations, but can ultimately
  10. # be expressed in terms of "primitive" operations from torch._prims.
  11. #
  12. # These references are experimental.
  13. # See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577
  14. # for additional context.
  15. #
  16. # Python Reference OpInfos should be added to the python_ref_db list below.
  17. # Tests can opt-into running on these references by including
  18. # that list in the Sequence they pass to the @ops decorator.
  19. #
  20. # When a Python Reference OpInfo is constructed a pointer to an
  21. # existing OpInfo must be provided using the torch_opinfo_name kwarg.
  22. # The existing OpInfo with that name and the requested variant name
  23. # (torch_opinfo_variant_name, empty by default) is looked up and inherited from.
  24. #
  25. # Instead of just inheriting the existing OpInfo's metadata, the
  26. # Python Reference OpInfos inherit the existing OpInfo's
  27. # construction arguments. These arguments can be overridden
  28. # by adding kwargs to the constructor.
  29. def _find_referenced_opinfo(referenced_name, variant_name, *, op_db=None):
  30. """
  31. Finds the OpInfo with the given name that has no variant name.
  32. """
  33. # NOTE: searching the global op_db doesn't work when OpInfos are split into
  34. # different modules, as otherwise the op_db will not be fully constructed
  35. # yet. So, instead the local op_db must be passed in explicitly.
  36. if op_db is None:
  37. from torch.testing._internal.common_methods_invocations import op_db
  38. for opinfo in op_db:
  39. if opinfo.name == referenced_name and opinfo.variant_test_name == variant_name:
  40. return opinfo
  41. def _inherit_constructor_args(name, op, inherited, overrides):
  42. # inherits metadata
  43. common_kwargs = {
  44. "name": name,
  45. "op": op,
  46. "aliases": None, # TODO add a check for alias coverage
  47. "method_variant": None,
  48. "inplace_variant": None, # TODO: add a check for inplace coverage
  49. "supports_scripting": False,
  50. }
  51. # Acquires inherited kwargs
  52. kwargs = inherited.copy()
  53. # Fixes metadata
  54. if "kwargs" in kwargs:
  55. kwargs.update(kwargs["kwargs"])
  56. del kwargs["kwargs"]
  57. if "self" in kwargs:
  58. del kwargs["self"]
  59. if "__class__" in kwargs:
  60. del kwargs["__class__"]
  61. if "skips" in kwargs:
  62. del kwargs["skips"]
  63. if "decorators" in kwargs:
  64. del kwargs["decorators"]
  65. # Overrides metadata
  66. kwargs.update(common_kwargs)
  67. kwargs.update(overrides)
  68. # At the moment no prims support autograd, so we must not run autograd
  69. # tests e.g. when testing dtype support. Once we start writing autograd
  70. # formulas for prims this can be removed.
  71. kwargs["supports_autograd"] = False
  72. kwargs["supports_gradgrad"] = False
  73. kwargs["supports_fwgrad_bwgrad"] = False
  74. kwargs["supports_inplace_autograd"] = False
  75. kwargs["supports_forward_ad"] = False
  76. return kwargs
  77. class PythonRefInfo(OpInfo):
  78. """
  79. An OpInfo for a Python reference of an OpInfo base class operation.
  80. """
  81. def __init__(
  82. self,
  83. name, # the stringname of the callable Python reference
  84. *,
  85. op=None, # the function variant of the operation, populated as torch.<name> if None
  86. op_db=None, # The database of opinfos to search for the parent opinfo
  87. torch_opinfo_name, # the string name of the corresponding torch opinfo
  88. torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
  89. validate_view_consistency=True,
  90. **kwargs,
  91. ): # additional kwargs override kwargs inherited from the torch opinfo
  92. self.torch_opinfo_name = torch_opinfo_name
  93. self.torch_opinfo_variant_name = torch_opinfo_variant_name
  94. self.torch_opinfo = _find_referenced_opinfo(
  95. torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
  96. )
  97. self.validate_view_consistency = validate_view_consistency
  98. assert isinstance(self.torch_opinfo, OpInfo)
  99. inherited = self.torch_opinfo._original_opinfo_args
  100. ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
  101. super().__init__(**ukwargs)
  102. class ReductionPythonRefInfo(ReductionOpInfo):
  103. """
  104. An OpInfo for a Python reference of an elementwise unary operation.
  105. """
  106. def __init__(
  107. self,
  108. name, # the stringname of the callable Python reference
  109. *,
  110. op=None, # the function variant of the operation, populated as torch.<name> if None
  111. op_db=None, # The database of opinfos to search for the parent opinfo
  112. torch_opinfo_name, # the string name of the corresponding torch opinfo
  113. torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
  114. **kwargs,
  115. ): # additional kwargs override kwargs inherited from the torch opinfo
  116. self.torch_opinfo_name = torch_opinfo_name
  117. self.torch_opinfo_variant_name = torch_opinfo_variant_name
  118. self.torch_opinfo = _find_referenced_opinfo(
  119. torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
  120. )
  121. assert isinstance(self.torch_opinfo, ReductionOpInfo)
  122. inherited = self.torch_opinfo._original_reduction_args
  123. ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
  124. # See https://github.com/pytorch/pytorch/issues/77216
  125. self.validate_view_consistency = False
  126. super().__init__(**ukwargs)
  127. class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo):
  128. """
  129. An OpInfo for a Python reference of an elementwise unary operation.
  130. """
  131. def __init__(
  132. self,
  133. name, # the stringname of the callable Python reference
  134. *,
  135. op=None, # the function variant of the operation, populated as torch.<name> if None
  136. op_db=None, # The database of opinfos to search for the parent opinfo
  137. torch_opinfo_name, # the string name of the corresponding torch opinfo
  138. torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
  139. validate_view_consistency=True,
  140. **kwargs,
  141. ): # additional kwargs override kwargs inherited from the torch opinfo
  142. self.torch_opinfo_name = torch_opinfo_name
  143. self.torch_opinfo_variant_name = torch_opinfo_variant_name
  144. self.torch_opinfo = _find_referenced_opinfo(
  145. torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
  146. )
  147. self.validate_view_consistency = validate_view_consistency
  148. assert isinstance(self.torch_opinfo, UnaryUfuncInfo)
  149. inherited = self.torch_opinfo._original_unary_ufunc_args
  150. ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
  151. super().__init__(**ukwargs)
  152. class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo):
  153. """
  154. An OpInfo for a Python reference of an elementwise binary operation.
  155. """
  156. def __init__(
  157. self,
  158. name, # the stringname of the callable Python reference
  159. *,
  160. op=None, # the function variant of the operation, populated as torch.<name> if None
  161. op_db=None, # The database of opinfos to search for the parent opinfo
  162. torch_opinfo_name, # the string name of the corresponding torch opinfo
  163. torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
  164. **kwargs,
  165. ): # additional kwargs override kwargs inherited from the torch opinfo
  166. self.torch_opinfo_name = torch_opinfo_name
  167. self.torch_opinfo_variant_name = torch_opinfo_variant_name
  168. self.torch_opinfo = _find_referenced_opinfo(
  169. torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
  170. )
  171. assert isinstance(self.torch_opinfo, BinaryUfuncInfo)
  172. inherited = self.torch_opinfo._original_binary_ufunc_args
  173. ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
  174. super().__init__(**ukwargs)