# mypy: allow-untyped-defs
import dataclasses
import inspect
import sys
from typing import Any, Callable, Dict, Iterable, Tuple

import torch
import torch._utils_internal as _utils_internal
from torch import _C
  9. @dataclasses.dataclass
  10. class Kernel:
  11. """Models a (function, source location)"""
  12. func: Callable
  13. source: str
  14. def __call__(self, *args, **kwargs):
  15. return self.func(*args, **kwargs)
  16. class RegistrationHandle:
  17. """Does something when someone calls .destroy() on it"""
  18. def __init__(self, on_destroy: Callable):
  19. self._on_destroy = on_destroy
  20. def destroy(self) -> None:
  21. self._on_destroy()
  22. def get_source(stacklevel: int) -> str:
  23. """Get a string that represents the caller.
  24. Example: "/path/to/foo.py:42"
  25. Use stacklevel=1 to get the caller's source
  26. Use stacklevel=2 to get the caller's caller's source
  27. etc.
  28. """
  29. frame = inspect.getframeinfo(sys._getframe(stacklevel))
  30. source = f"{frame.filename}:{frame.lineno}"
  31. return source
  32. def parse_namespace(qualname: str) -> Tuple[str, str]:
  33. splits = qualname.split("::")
  34. if len(splits) != 2:
  35. raise ValueError(
  36. f"Expected `qualname` to be of the form "
  37. f'"namespace::name", but got {qualname}. '
  38. f"The qualname passed to the torch.library APIs must consist "
  39. f"of a namespace and a name, e.g. aten::sin"
  40. )
  41. return splits[0], splits[1]
  42. def lookup_op(qualname: str) -> torch._ops.OpOverload:
  43. namespace, name = parse_namespace(qualname)
  44. if "." in name:
  45. name, overload = name.split(".")
  46. else:
  47. overload = "default"
  48. ns = getattr(torch.ops, namespace)
  49. packet = getattr(ns, name)
  50. return getattr(packet, overload)
  51. def is_builtin(op: torch._ops.OpOverload) -> bool:
  52. assert isinstance(op, torch._ops.OpOverload)
  53. return op.namespace in {"aten", "prim", "prims"}
  54. def is_functional_schema(schema: Any) -> bool:
  55. """Check if the schema is functional.
  56. An operator is functional if:
  57. - it does not mutate any of its inputs
  58. - it does not return a view on any of its inputs
  59. - it has at least one return
  60. """
  61. def is_functional(schema):
  62. if schema.is_mutable:
  63. return False
  64. rets = schema.returns
  65. is_non_mutating_view = len(rets) > 0 and any(
  66. r.alias_info is not None and not r.alias_info.is_write for r in rets
  67. )
  68. if is_non_mutating_view:
  69. return False
  70. if not schema.returns:
  71. return False
  72. return True
  73. if isinstance(schema, torch._C.FunctionSchema):
  74. return is_functional(schema)
  75. # Lazy import because not all PyTorch builds have torchgen
  76. from torchgen.model import FunctionSchema
  77. if isinstance(schema, str):
  78. schema = FunctionSchema.parse(schema)
  79. assert isinstance(schema, FunctionSchema)
  80. return is_functional(schema)
  81. # should be torch._C.JitType but that annotation is busted
  82. def is_tensorlist_like_type(typ: Any) -> bool:
  83. return (
  84. typ == _C.ListType(_C.TensorType.get())
  85. or typ == _C.ListType(_C.OptionalType(_C.TensorType.get()))
  86. or typ == _C.OptionalType(_C.ListType(_C.TensorType.get()))
  87. or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get())))
  88. )
  89. # should be torch._C.JitType but that annotation is busted
  90. def is_tensor_like_type(typ: Any) -> bool:
  91. return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get())
  92. def mutates_and_returns_first_arg(op: torch._ops.OpOverload):
  93. """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.
  94. TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
  95. but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
  96. Figure this out.
  97. Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)
  98. """
  99. if op.namespace != "aten":
  100. return False
  101. schema = op._schema
  102. if not len(schema.returns) == 1:
  103. return False
  104. if schema.returns[0].alias_info is None:
  105. return False
  106. alias_set = schema.returns[0].alias_info.after_set
  107. if len(alias_set) != 1:
  108. return False
  109. loc = next(iter(alias_set))
  110. if len(schema.arguments) < 1:
  111. return False
  112. first_arg = schema.arguments[0]
  113. if first_arg.alias_info is None:
  114. return False
  115. if not first_arg.alias_info.is_write:
  116. return False
  117. alias_set = first_arg.alias_info.after_set
  118. if len(alias_set) != 1:
  119. return False
  120. if loc != next(iter(alias_set)):
  121. return False
  122. for arg in schema.arguments[1:]:
  123. if arg.alias_info is not None:
  124. return False
  125. return True
  126. def fill_defaults(schema, args, kwargs):
  127. new_args = []
  128. new_kwargs = {}
  129. for i in range(len(schema.arguments)):
  130. info = schema.arguments[i]
  131. if info.kwarg_only:
  132. if info.name in kwargs:
  133. new_kwargs[info.name] = kwargs[info.name]
  134. else:
  135. new_kwargs[info.name] = info.default_value
  136. else:
  137. if i < len(args):
  138. new_args.append(args[i])
  139. else:
  140. new_args.append(info.default_value)
  141. return tuple(new_args), new_kwargs
  142. def zip_schema(
  143. schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
  144. ) -> Iterable[Tuple[_C.Argument, Any]]:
  145. """zips schema.arguments and (args, kwargs) together.
  146. Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
  147. that is, kwargs must be keyword-only arguments and default values may be omitted.
  148. """
  149. assert len(schema.arguments) >= len(args) + len(kwargs)
  150. for i in range(len(schema.arguments)):
  151. info = schema.arguments[i]
  152. if info.kwarg_only:
  153. if info.name in kwargs:
  154. yield info, kwargs[info.name]
  155. continue
  156. if i >= len(args):
  157. # args that are equal to their default values are not populated
  158. # if they are followed by args that are equal to their defaults.
  159. # Skip these.
  160. continue
  161. yield info, args[i]
  162. return
  163. def can_generate_trivial_fake_impl(op: torch._ops.OpOverload) -> bool:
  164. assert isinstance(op, torch._ops.OpOverload)
  165. if is_builtin(op):
  166. # We control the built-ins. These may (in rare cases)
  167. # do input metadata mutation (which we have banned on custom ops)
  168. return False
  169. schema = op._schema
  170. # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
  171. if not schema.is_mutable:
  172. return False
  173. if len(schema.returns) > 0:
  174. return False
  175. # If the op returns nothing, then it has a trivial fake impl.
  176. return True
  177. def requires_set_python_module() -> bool:
  178. """If an op was defined in C++ and extended from Python using the
  179. torch.library APIs, returns if we require that there have been a
  180. m.set_python_module("mylib.ops") call from C++ that associates
  181. the C++ op with a python module.
  182. """
  183. return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)
  184. def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
  185. assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
  186. overload_types = []
  187. args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
  188. for a in args_flattened:
  189. # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
  190. # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
  191. # where in one case we only include tensors with the python key, and in another
  192. # we include **all** tensors.
  193. if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(a).has(
  194. torch._C.DispatchKey.Python
  195. ):
  196. overload_types.append(type(a))
  197. # TODO: check that I got these args correct (in C++, we pass in "0000"??)
  198. return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
  199. def has_kwarg_only_args(schema: _C.FunctionSchema):
  200. return any(a.kwarg_only for a in schema.arguments)
  201. def has_kwarg_only_tensors(schema: _C.FunctionSchema):
  202. for a in schema.arguments:
  203. if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)):
  204. continue
  205. if not a.kwarg_only:
  206. continue
  207. return True
  208. return False